1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A class to represent a relation over integer tuples. A relation is
10 // represented as a constraint system over a space of tuples of integer valued
11 // variables supporting symbolic variables and existential quantification.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/Presburger/IntegerRelation.h"
16 #include "mlir/Analysis/Presburger/LinearTransform.h"
17 #include "mlir/Analysis/Presburger/PWMAFunction.h"
18 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
19 #include "mlir/Analysis/Presburger/Simplex.h"
20 #include "mlir/Analysis/Presburger/Utils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "presburger"
26 
27 using namespace mlir;
28 using namespace presburger;
29 
30 using llvm::SmallDenseMap;
31 using llvm::SmallDenseSet;
32 
33 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
34   return std::make_unique<IntegerRelation>(*this);
35 }
36 
37 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
38   return std::make_unique<IntegerPolyhedron>(*this);
39 }
40 
41 void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
42   assert(space.getNumVars() == oSpace.getNumVars() && "invalid space!");
43   space = oSpace;
44 }
45 
46 void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
47   assert(oSpace.getNumLocalVars() == 0 && "no locals should be present!");
48   assert(oSpace.getNumVars() <= getNumVars() && "invalid space!");
49   unsigned newNumLocals = getNumVars() - oSpace.getNumVars();
50   space = oSpace;
51   space.insertVar(VarKind::Local, 0, newNumLocals);
52 }
53 
54 void IntegerRelation::append(const IntegerRelation &other) {
55   assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
56 
57   inequalities.reserveRows(inequalities.getNumRows() +
58                            other.getNumInequalities());
59   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
60 
61   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
62     addInequality(other.getInequality(r));
63   }
64   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
65     addEquality(other.getEquality(r));
66   }
67 }
68 
69 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const {
70   IntegerRelation result = *this;
71   result.mergeLocalVars(other);
72   result.append(other);
73   return result;
74 }
75 
76 bool IntegerRelation::isEqual(const IntegerRelation &other) const {
77   assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
78   return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
79 }
80 
81 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
82   assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
83   return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
84 }
85 
86 MaybeOptimum<SmallVector<Fraction, 8>>
87 IntegerRelation::findRationalLexMin() const {
88   assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
89   MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
90       LexSimplex(*this).findRationalLexMin();
91 
92   if (!maybeLexMin.isBounded())
93     return maybeLexMin;
94 
95   // The Simplex returns the lexmin over all the variables including locals. But
96   // locals are not actually part of the space and should not be returned in the
97   // result. Since the locals are placed last in the list of variables, they
98   // will be minimized last in the lexmin. So simply truncating out the locals
99   // from the end of the answer gives the desired lexmin over the dimensions.
100   assert(maybeLexMin->size() == getNumVars() &&
101          "Incorrect number of vars in lexMin!");
102   maybeLexMin->resize(getNumDimAndSymbolVars());
103   return maybeLexMin;
104 }
105 
106 MaybeOptimum<SmallVector<int64_t, 8>>
107 IntegerRelation::findIntegerLexMin() const {
108   assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
109   MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
110       LexSimplex(*this).findIntegerLexMin();
111 
112   if (!maybeLexMin.isBounded())
113     return maybeLexMin.getKind();
114 
115   // The Simplex returns the lexmin over all the variables including locals. But
116   // locals are not actually part of the space and should not be returned in the
117   // result. Since the locals are placed last in the list of variables, they
118   // will be minimized last in the lexmin. So simply truncating out the locals
119   // from the end of the answer gives the desired lexmin over the dimensions.
120   assert(maybeLexMin->size() == getNumVars() &&
121          "Incorrect number of vars in lexMin!");
122   maybeLexMin->resize(getNumDimAndSymbolVars());
123   return maybeLexMin;
124 }
125 
126 static bool rangeIsZero(ArrayRef<int64_t> range) {
127   return llvm::all_of(range, [](int64_t x) { return x == 0; });
128 }
129 
130 static void removeConstraintsInvolvingVarRange(IntegerRelation &poly,
131                                                unsigned begin, unsigned count) {
132   // We loop until i > 0 and index into i - 1 to avoid sign issues.
133   //
134   // We iterate backwards so that whether we remove constraint i - 1 or not, the
135   // next constraint to be tested is always i - 2.
136   for (unsigned i = poly.getNumEqualities(); i > 0; i--)
137     if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count)))
138       poly.removeEquality(i - 1);
139   for (unsigned i = poly.getNumInequalities(); i > 0; i--)
140     if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count)))
141       poly.removeInequality(i - 1);
142 }
143 
144 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
145   return {getSpace(), getNumInequalities(), getNumEqualities()};
146 }
147 
148 void IntegerRelation::truncateVarKind(VarKind kind, unsigned num) {
149   unsigned curNum = getNumVarKind(kind);
150   assert(num <= curNum && "Can't truncate to more vars!");
151   removeVarRange(kind, num, curNum);
152 }
153 
154 void IntegerRelation::truncateVarKind(VarKind kind,
155                                       const CountsSnapshot &counts) {
156   truncateVarKind(kind, counts.getSpace().getNumVarKind(kind));
157 }
158 
159 void IntegerRelation::truncate(const CountsSnapshot &counts) {
160   truncateVarKind(VarKind::Domain, counts);
161   truncateVarKind(VarKind::Range, counts);
162   truncateVarKind(VarKind::Symbol, counts);
163   truncateVarKind(VarKind::Local, counts);
164   removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
165   removeEqualityRange(counts.getNumEqs(), getNumEqualities());
166 }
167 
/// Returns a PresburgerRelation for this relation in which the local variables
/// that lack a division representation have been projected out, as the name
/// suggests. Locals that do have one are kept as locals.
///
/// Approach: non-div locals are moved to the end of the variables, and then a
/// symbolic integer lexmin is computed that treats every other variable as a
/// "symbol". The set of symbol values for which the non-div locals can be
/// satisfied is the domain of the lexmin plus the unbounded domain.
PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
  // If there are no locals, we're done.
  if (getNumLocalVars() == 0)
    return PresburgerRelation(*this);

  // Move all the non-div locals to the end, as the current API to
  // SymbolicLexMin requires these to form a contiguous range.
  //
  // Take a copy so we can perform mutations.
  IntegerRelation copy = *this;
  std::vector<MaybeLocalRepr> reprs;
  copy.getLocalReprs(reprs);

  // Iterate through all the locals. The last `numNonDivLocals` are the locals
  // that have been scanned already and do not have division representations.
  // The loop bound `e - numNonDivLocals` is re-evaluated every iteration, so
  // it shrinks as non-div locals are moved to the end.
  unsigned numNonDivLocals = 0;
  unsigned offset = copy.getVarKindOffset(VarKind::Local);
  for (unsigned i = 0, e = copy.getNumLocalVars(); i < e - numNonDivLocals;) {
    if (!reprs[i]) {
      // Whenever we come across a local that does not have a division
      // representation, we swap it to the `numNonDivLocals`-th last position
      // and increment `numNonDivLocals`. `reprs` also needs to be swapped.
      // `i` is deliberately not advanced: the local just swapped into
      // position `i` has not been examined yet.
      copy.swapVar(offset + i, offset + e - numNonDivLocals - 1);
      std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
      ++numNonDivLocals;
      continue;
    }
    ++i;
  }

  // If there are no non-div locals, we're done. No swaps were performed in
  // that case, so `*this` can be used directly instead of `copy`.
  if (numNonDivLocals == 0)
    return PresburgerRelation(*this);

  // We computeSymbolicIntegerLexMin by considering the non-div locals as
  // "non-symbols" and considering everything else as "symbols". This will
  // compute a function mapping assignments to "symbols" to the
  // lexicographically minimal valid assignment of "non-symbols", when a
  // satisfying assignment exists. It separately returns the set of assignments
  // to the "symbols" such that a satisfying assignment to the "non-symbols"
  // exists but the lexmin is unbounded. We basically want to find the set of
  // values of the "symbols" such that an assignment to the "non-symbols"
  // exists, which is the union of the domain of the returned lexmin function
  // and the returned set of assignments to the "symbols" that makes the lexmin
  // unbounded.
  // The symbols are the first `copy.getNumVars() - numNonDivLocals` variables
  // (symbolOffset 0). The symbol domain is an unconstrained set with that many
  // dims.
  SymbolicLexMin lexminResult =
      SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
                         IntegerPolyhedron(PresburgerSpace::getSetSpace(
                             /*numDims=*/copy.getNumVars() - numNonDivLocals)))
          .computeSymbolicIntegerLexMin();
  PresburgerRelation result =
      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);

  // The result set might lie in the wrong space -- all its ids are dims.
  // Set it to the desired space and return.
  // The target space is this relation's space without locals. The local
  // `space` shadows the member of the same name. The remaining (division)
  // locals are presumably re-created as locals by
  // PresburgerRelation::setSpace -- NOTE(review): confirm against that
  // implementation.
  PresburgerSpace space = getSpace();
  space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
  result.setSpace(space);
  return result;
}
228 
229 SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
230   // Symbol and Domain vars will be used as symbols for symbolic lexmin.
231   // In other words, for every value of the symbols and domain, return the
232   // lexmin value of the (range, locals).
233   llvm::SmallBitVector isSymbol(getNumVars(), false);
234   isSymbol.set(getVarKindOffset(VarKind::Symbol),
235                getVarKindEnd(VarKind::Symbol));
236   isSymbol.set(getVarKindOffset(VarKind::Domain),
237                getVarKindEnd(VarKind::Domain));
238   // Compute the symbolic lexmin of the dims and locals, with the symbols being
239   // the actual symbols of this set.
240   SymbolicLexMin result =
241       SymbolicLexSimplex(*this,
242                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
243                              /*numDims=*/getNumDomainVars(),
244                              /*numSymbols=*/getNumSymbolVars())),
245                          isSymbol)
246           .computeSymbolicIntegerLexMin();
247 
248   // We want to return only the lexmin over the dims, so strip the locals from
249   // the computed lexmin.
250   result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
251                                getNumLocalVars());
252   return result;
253 }
254 
255 PresburgerRelation
256 IntegerRelation::subtract(const PresburgerRelation &set) const {
257   return PresburgerRelation(*this).subtract(set);
258 }
259 
260 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
261   assert(pos <= getNumVarKind(kind));
262 
263   unsigned insertPos = space.insertVar(kind, pos, num);
264   inequalities.insertColumns(insertPos, num);
265   equalities.insertColumns(insertPos, num);
266   return insertPos;
267 }
268 
269 unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) {
270   unsigned pos = getNumVarKind(kind);
271   return insertVar(kind, pos, num);
272 }
273 
274 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
275   assert(eq.size() == getNumCols());
276   unsigned row = equalities.appendExtraRow();
277   for (unsigned i = 0, e = eq.size(); i < e; ++i)
278     equalities(row, i) = eq[i];
279 }
280 
281 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
282   assert(inEq.size() == getNumCols());
283   unsigned row = inequalities.appendExtraRow();
284   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
285     inequalities(row, i) = inEq[i];
286 }
287 
288 void IntegerRelation::removeVar(VarKind kind, unsigned pos) {
289   removeVarRange(kind, pos, pos + 1);
290 }
291 
292 void IntegerRelation::removeVar(unsigned pos) { removeVarRange(pos, pos + 1); }
293 
294 void IntegerRelation::removeVarRange(VarKind kind, unsigned varStart,
295                                      unsigned varLimit) {
296   assert(varLimit <= getNumVarKind(kind));
297 
298   if (varStart >= varLimit)
299     return;
300 
301   // Remove eliminated variables from the constraints.
302   unsigned offset = getVarKindOffset(kind);
303   equalities.removeColumns(offset + varStart, varLimit - varStart);
304   inequalities.removeColumns(offset + varStart, varLimit - varStart);
305 
306   // Remove eliminated variables from the space.
307   space.removeVarRange(kind, varStart, varLimit);
308 }
309 
/// Removes the variables at absolute positions [varStart, varLimit). The range
/// may span several variable kinds. It is split into one piece per kind, and
/// each piece is removed through the kind-relative removeVarRange overload.
void IntegerRelation::removeVarRange(unsigned varStart, unsigned varLimit) {
  assert(varLimit <= getNumVars());

  if (varStart >= varLimit)
    return;

  // Helper function to remove vars of the specified kind in the given range
  // [start, limit). The range is absolute (i.e. it is not relative to the kind
  // of variable). Also updates `limit` to reflect the deleted variables.
  auto removeVarKindInRange = [this](VarKind kind, unsigned &start,
                                     unsigned &limit) {
    if (start >= limit)
      return;

    // Offset and size are re-queried each call because removals for an
    // earlier kind shift the offsets of later kinds.
    unsigned offset = getVarKindOffset(kind);
    unsigned num = getNumVarKind(kind);

    // Get `start`, `limit` relative to the specified kind.
    // Both are clamped into [0, num]; a kind entirely outside the range yields
    // relativeStart == relativeLimit, and the removal is then a no-op.
    unsigned relativeStart =
        start <= offset ? 0 : std::min(num, start - offset);
    unsigned relativeLimit =
        limit <= offset ? 0 : std::min(num, limit - offset);

    // Remove vars of the specified kind in the relative range.
    removeVarRange(kind, relativeStart, relativeLimit);

    // Update `limit` to reflect deleted variables.
    // `start` does not need to be updated because any variables that are
    // deleted are after position `start`.
    limit -= relativeLimit - relativeStart;
  };

  // NOTE(review): this relies on the kinds being visited in increasing
  // column-offset order, so that shrinking `limit` after each kind correctly
  // describes the remaining range for the next one.
  removeVarKindInRange(VarKind::Domain, varStart, varLimit);
  removeVarKindInRange(VarKind::Range, varStart, varLimit);
  removeVarKindInRange(VarKind::Symbol, varStart, varLimit);
  removeVarKindInRange(VarKind::Local, varStart, varLimit);
}
347 
348 void IntegerRelation::removeEquality(unsigned pos) {
349   equalities.removeRow(pos);
350 }
351 
352 void IntegerRelation::removeInequality(unsigned pos) {
353   inequalities.removeRow(pos);
354 }
355 
356 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
357   if (start >= end)
358     return;
359   equalities.removeRows(start, end - start);
360 }
361 
362 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
363   if (start >= end)
364     return;
365   inequalities.removeRows(start, end - start);
366 }
367 
368 void IntegerRelation::swapVar(unsigned posA, unsigned posB) {
369   assert(posA < getNumVars() && "invalid position A");
370   assert(posB < getNumVars() && "invalid position B");
371 
372   if (posA == posB)
373     return;
374 
375   inequalities.swapColumns(posA, posB);
376   equalities.swapColumns(posA, posB);
377 }
378 
379 void IntegerRelation::clearConstraints() {
380   equalities.resizeVertically(0);
381   inequalities.resizeVertically(0);
382 }
383 
384 /// Gather all lower and upper bounds of the variable at `pos`, and
385 /// optionally any equalities on it. In addition, the bounds are to be
386 /// independent of variables in position range [`offset`, `offset` + `num`).
387 void IntegerRelation::getLowerAndUpperBoundIndices(
388     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
389     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
390     unsigned offset, unsigned num) const {
391   assert(pos < getNumVars() && "invalid position");
392   assert(offset + num < getNumCols() && "invalid range");
393 
394   // Checks for a constraint that has a non-zero coeff for the variables in
395   // the position range [offset, offset + num) while ignoring `pos`.
396   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
397     unsigned c, f;
398     auto cst = isEq ? getEquality(r) : getInequality(r);
399     for (c = offset, f = offset + num; c < f; ++c) {
400       if (c == pos)
401         continue;
402       if (cst[c] != 0)
403         break;
404     }
405     return c < f;
406   };
407 
408   // Gather all lower bounds and upper bounds of the variable. Since the
409   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
410   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
411   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
412     // The bounds are to be independent of [offset, offset + num) columns.
413     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
414       continue;
415     if (atIneq(r, pos) >= 1) {
416       // Lower bound.
417       lbIndices->push_back(r);
418     } else if (atIneq(r, pos) <= -1) {
419       // Upper bound.
420       ubIndices->push_back(r);
421     }
422   }
423 
424   // An equality is both a lower and upper bound. Record any equalities
425   // involving the pos^th variable.
426   if (!eqIndices)
427     return;
428 
429   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
430     if (atEq(r, pos) == 0)
431       continue;
432     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
433       continue;
434     eqIndices->push_back(r);
435   }
436 }
437 
438 bool IntegerRelation::hasConsistentState() const {
439   if (!inequalities.hasConsistentState())
440     return false;
441   if (!equalities.hasConsistentState())
442     return false;
443   return true;
444 }
445 
446 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
447   if (values.empty())
448     return;
449   assert(pos + values.size() <= getNumVars() &&
450          "invalid position or too many values");
451   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
452   // constant term and removing the var x_j. We do this for all the vars
453   // pos, pos + 1, ... pos + values.size() - 1.
454   unsigned constantColPos = getNumCols() - 1;
455   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
456     inequalities.addToColumn(i + pos, constantColPos, values[i]);
457   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
458     equalities.addToColumn(i + pos, constantColPos, values[i]);
459   removeVarRange(pos, pos + values.size());
460 }
461 
462 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
463   *this = other;
464 }
465 
466 // Searches for a constraint with a non-zero coefficient at `colIdx` in
467 // equality (isEq=true) or inequality (isEq=false) constraints.
468 // Returns true and sets row found in search in `rowIdx`, false otherwise.
469 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
470                                                   unsigned *rowIdx) const {
471   assert(colIdx < getNumCols() && "position out of bounds");
472   auto at = [&](unsigned rowIdx) -> int64_t {
473     return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
474   };
475   unsigned e = isEq ? getNumEqualities() : getNumInequalities();
476   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
477     if (at(*rowIdx) != 0) {
478       return true;
479     }
480   }
481   return false;
482 }
483 
484 void IntegerRelation::normalizeConstraintsByGCD() {
485   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
486     equalities.normalizeRow(i);
487   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
488     inequalities.normalizeRow(i);
489 }
490 
491 bool IntegerRelation::hasInvalidConstraint() const {
492   assert(hasConsistentState());
493   auto check = [&](bool isEq) -> bool {
494     unsigned numCols = getNumCols();
495     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
496     for (unsigned i = 0, e = numRows; i < e; ++i) {
497       unsigned j;
498       for (j = 0; j < numCols - 1; ++j) {
499         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
500         // Skip rows with non-zero variable coefficients.
501         if (v != 0)
502           break;
503       }
504       if (j < numCols - 1) {
505         continue;
506       }
507       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
508       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
509       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
510       if ((isEq && v != 0) || (!isEq && v < 0)) {
511         return true;
512       }
513     }
514     return false;
515   };
516   if (check(/*isEq=*/true))
517     return true;
518   return check(/*isEq=*/false);
519 }
520 
/// Eliminate variable from constraint at `rowIdx` based on coefficient at
/// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
/// updated as they have already been eliminated.
///
/// The pivot row is always taken from the equalities (`atEq(pivotRow, ...)`),
/// while `rowIdx` refers to an equality if `isEq` and to an inequality
/// otherwise. The target row is replaced by
///   pivotMultiplier * pivotRow + rowMultiplier * targetRow,
/// with the multipliers chosen so that the coefficient at `pivotCol` becomes
/// zero. rowMultiplier is lcm / |leadCoeff|, which is positive assuming
/// mlir::lcm returns a positive value (NOTE(review): confirm). A positive
/// multiplier keeps the direction of an inequality unchanged. Adding any
/// multiple of an equality is always valid.
static void eliminateFromConstraint(IntegerRelation *constraints,
                                    unsigned rowIdx, unsigned pivotRow,
                                    unsigned pivotCol, unsigned elimColStart,
                                    bool isEq) {
  // Skip if the equality 'rowIdx' is the pivot row itself.
  if (isEq && rowIdx == pivotRow)
    return;
  auto at = [&](unsigned i, unsigned j) -> int64_t {
    return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
  };
  int64_t leadCoeff = at(rowIdx, pivotCol);
  // Skip if leading coefficient at 'rowIdx' is already zero.
  if (leadCoeff == 0)
    return;
  int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
  // If both coefficients have the same sign, the pivot row must be subtracted
  // (negative multiplier) to cancel; otherwise it is added.
  int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
  int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
  int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
  int64_t rowMultiplier = lcm / std::abs(leadCoeff);

  unsigned numCols = constraints->getNumCols();
  for (unsigned j = 0; j < numCols; ++j) {
    // Skip updating column 'j' if it was just eliminated.
    if (j >= elimColStart && j < pivotCol)
      continue;
    // NOTE(review): these int64_t products are not overflow-checked.
    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
                rowMultiplier * at(rowIdx, j);
    isEq ? constraints->atEq(rowIdx, j) = v
         : constraints->atIneq(rowIdx, j) = v;
  }
}
555 
556 /// Returns the position of the variable that has the minimum <number of lower
557 /// bounds> times <number of upper bounds> from the specified range of
558 /// variables [start, end). It is often best to eliminate in the increasing
559 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
560 /// that many new constraints.
561 static unsigned getBestVarToEliminate(const IntegerRelation &cst,
562                                       unsigned start, unsigned end) {
563   assert(start < cst.getNumVars() && end < cst.getNumVars() + 1);
564 
565   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
566     unsigned numLb = 0;
567     unsigned numUb = 0;
568     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
569       if (cst.atIneq(r, pos) > 0) {
570         ++numLb;
571       } else if (cst.atIneq(r, pos) < 0) {
572         ++numUb;
573       }
574     }
575     return numLb * numUb;
576   };
577 
578   unsigned minLoc = start;
579   unsigned min = getProductOfNumLowerUpperBounds(start);
580   for (unsigned c = start + 1; c < end; c++) {
581     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
582     if (numLbUbProduct < min) {
583       min = numLbUbProduct;
584       minLoc = c;
585     }
586   }
587   return minLoc;
588 }
589 
// Checks for emptiness of the set by eliminating variables successively and
// using the GCD test (on all equality constraints) and checking for trivially
// invalid constraints. Returns 'true' if the constraint system is found to be
// empty; false otherwise.
//
// Phases, all on a private copy:
//   1. Quick GCD / trivial-constraint checks on the original system.
//   2. Removal of redundant locals, then the same checks again.
//   3. Gaussian elimination, checking after every step.
//   4. Fourier-Motzkin elimination of the remaining variables.
// Note that 'false' is also returned, without a definitive answer, when phase
// 4 produces at least kExplosionFactor * getNumVars() constraints.
bool IntegerRelation::isEmpty() const {
  if (isEmptyByGCDTest() || hasInvalidConstraint())
    return true;

  // Work on a copy; `this` is const and must not be modified.
  IntegerRelation tmpCst(*this);

  // First, eliminate as many local variables as possible using equalities.
  tmpCst.removeRedundantLocalVars();
  if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
    return true;

  // Eliminate as many variables as possible using Gaussian elimination.
  // getNumVars() is re-read every iteration because elimination may remove
  // variables from tmpCst.
  unsigned currentPos = 0;
  while (currentPos < tmpCst.getNumVars()) {
    tmpCst.gaussianEliminateVars(currentPos, tmpCst.getNumVars());
    ++currentPos;
    // We check emptiness through trivial checks after eliminating each ID to
    // detect emptiness early. Since the checks isEmptyByGCDTest() and
    // hasInvalidConstraint() are linear time and single sweep on the constraint
    // buffer, this appears reasonable - but can optimize in the future.
    if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
      return true;
  }

  // Eliminate the remaining using FM.
  // One variable is eliminated per iteration, always the cheapest one
  // according to getBestVarToEliminate.
  for (unsigned i = 0, e = tmpCst.getNumVars(); i < e; i++) {
    tmpCst.fourierMotzkinEliminate(
        getBestVarToEliminate(tmpCst, 0, tmpCst.getNumVars()));
    // Check for a constraint explosion. This rarely happens in practice, but
    // this check exists as a safeguard against improperly constructed
    // constraint systems or artificially created arbitrarily complex systems
    // that aren't the intended use case for IntegerRelation. This is
    // needed since FM has a worst case exponential complexity in theory.
    // The threshold is relative to the variable count of the original
    // relation, not of tmpCst.
    if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumVars()) {
      LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
      return false;
    }

    // FM wouldn't have modified the equalities in any way. So no need to again
    // run GCD test. Check for trivial invalid constraints.
    if (tmpCst.hasInvalidConstraint())
      return true;
  }
  return false;
}
639 
640 // Runs the GCD test on all equality constraints. Returns 'true' if this test
641 // fails on any equality. Returns 'false' otherwise.
642 // This test can be used to disprove the existence of a solution. If it returns
643 // true, no integer solution to the equality constraints can exist.
644 //
645 // GCD test definition:
646 //
647 // The equality constraint:
648 //
649 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
650 //
651 // has an integer solution iff:
652 //
653 //  GCD of c_1, c_2, ..., c_n divides c_0.
654 //
655 bool IntegerRelation::isEmptyByGCDTest() const {
656   assert(hasConsistentState());
657   unsigned numCols = getNumCols();
658   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
659     uint64_t gcd = std::abs(atEq(i, 0));
660     for (unsigned j = 1; j < numCols - 1; ++j) {
661       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
662     }
663     int64_t v = std::abs(atEq(i, numCols - 1));
664     if (gcd > 0 && (v % gcd != 0)) {
665       return true;
666     }
667   }
668   return false;
669 }
670 
671 // Returns a matrix where each row is a vector along which the polytope is
672 // bounded. The span of the returned vectors is guaranteed to contain all
673 // such vectors. The returned vectors are NOT guaranteed to be linearly
674 // independent. This function should not be called on empty sets.
675 //
676 // It is sufficient to check the perpendiculars of the constraints, as the set
677 // of perpendiculars which are bounded must span all bounded directions.
678 Matrix IntegerRelation::getBoundedDirections() const {
679   // Note that it is necessary to add the equalities too (which the constructor
680   // does) even though we don't need to check if they are bounded; whether an
681   // inequality is bounded or not depends on what other constraints, including
682   // equalities, are present.
683   Simplex simplex(*this);
684 
685   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
686                                "direction is bounded in an empty set.");
687 
688   SmallVector<unsigned, 8> boundedIneqs;
689   // The constructor adds the inequalities to the simplex first, so this
690   // processes all the inequalities.
691   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
692     if (simplex.isBoundedAlongConstraint(i))
693       boundedIneqs.push_back(i);
694   }
695 
696   // The direction vector is given by the coefficients and does not include the
697   // constant term, so the matrix has one fewer column.
698   unsigned dirsNumCols = getNumCols() - 1;
699   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
700 
701   // Copy the bounded inequalities.
702   unsigned row = 0;
703   for (unsigned i : boundedIneqs) {
704     for (unsigned col = 0; col < dirsNumCols; ++col)
705       dirs(row, col) = atIneq(i, col);
706     ++row;
707   }
708 
709   // Copy the equalities. All the equalities' perpendiculars are bounded.
710   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
711     for (unsigned col = 0; col < dirsNumCols; ++col)
712       dirs(row, col) = atEq(i, col);
713     ++row;
714   }
715 
716   return dirs;
717 }
718 
719 bool IntegerRelation::isIntegerEmpty() const { return !findIntegerSample(); }
720 
/// Returns an integer point of this set, or None if it has no integer point.
///
/// Let this set be S. If S is bounded then we directly call into the GBR
/// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
/// vectors v such that S extends to infinity along v or -v. In this case we
/// use an algorithm described in the integer set library (isl) manual and used
/// by the isl_set_sample function in that library. The algorithm is:
///
/// 1) Apply a unimodular transform T to S to obtain S*T, such that all
/// dimensions in which S*T is bounded lie in the linear span of a prefix of the
/// dimensions.
///
/// 2) Construct a set B by removing all constraints that involve
/// the unbounded dimensions and then deleting the unbounded dimensions. Note
/// that B is a bounded set.
///
/// 3) Try to obtain a sample from B using the GBR sampling
/// algorithm. If no sample is found, return that S is empty.
///
/// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
/// C. C is a full-dimensional cone and always contains a sample.
///
/// 5) Obtain an integer sample from C.
///
/// 6) Return T*v, where v is the concatenation of the samples from B and C.
///
/// The following is a sketch of a proof that
/// a) If the algorithm returns empty, then S is empty.
/// b) If the algorithm returns a sample, it is a valid sample in S.
///
/// The algorithm returns empty only if B is empty, in which case S*T is
/// certainly empty since B was obtained by removing constraints and then
/// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
/// v is in S*T iff T*v is in S. So in this case, since
/// S*T is empty, S is empty too.
///
/// Otherwise, the algorithm substitutes the sample from B into S*T. All the
/// constraints of S*T that did not involve unbounded dimensions are satisfied
/// by this substitution. All dimensions in the linear span of the dimensions
/// outside the prefix are unbounded in S*T (step 1). Substituting values for
/// the bounded dimensions cannot make these dimensions bounded, and these are
/// the only remaining dimensions in C, so C is unbounded along every vector (in
/// the positive or negative direction, or both). C is hence a full-dimensional
/// cone and therefore always contains an integer point.
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector along which S is
  // bounded, such that the linear span of all these vectors contains all
  // bounded directions of S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings m
  // to column echelon form is unimodular, so this is a suitable transform to
  // use in step 1 of the algorithm.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  IntegerRelation transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set.
  IntegerRelation boundedSet(transformedSet);
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumVars() - numBoundedDims;
  removeConstraintsInvolvingVarRange(boundedSet, numBoundedDims,
                                     numUnboundedDims);
  boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  transformedSet.setAndEliminate(0, *boundedSample);
  IntegerRelation &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.getNumVars(); ++j) {
      int64_t coeff = cone.atIneq(i, j);
      if (coeff < 0)
        cone.atIneq(i, cone.getNumVars()) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");

  // The sample will always exist since the shrunken cone is non-empty.
  SmallVector<Fraction, 8> shrunkenConeSample =
      *shrunkenConeSimplex.getRationalSample();

  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  SmallVector<int64_t, 8> &sample = *boundedSample;
  sample.append(coneSample.begin(), coneSample.end());
  return transform.postMultiplyWithColumn(sample);
}
868 
869 /// Helper to evaluate an affine expression at a point.
870 /// The expression is a list of coefficients for the dimensions followed by the
871 /// constant term.
872 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
873   assert(expr.size() == 1 + point.size() &&
874          "Dimensionalities of point and expression don't match!");
875   int64_t value = expr.back();
876   for (unsigned i = 0; i < point.size(); ++i)
877     value += expr[i] * point[i];
878   return value;
879 }
880 
881 /// A point satisfies an equality iff the value of the equality at the
882 /// expression is zero, and it satisfies an inequality iff the value of the
883 /// inequality at that point is non-negative.
884 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
885   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
886     if (valueAt(getEquality(i), point) != 0)
887       return false;
888   }
889   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
890     if (valueAt(getInequality(i), point) < 0)
891       return false;
892   }
893   return true;
894 }
895 
896 /// Just substitute the values given and check if an integer sample exists for
897 /// the local vars.
898 ///
899 /// TODO: this could be made more efficient by handling divisions separately.
900 /// Instead of finding an integer sample over all the locals, we can first
901 /// compute the values of the locals that have division representations and
902 /// only use the integer emptiness check for the locals that don't have this.
903 /// Handling this correctly requires ordering the divs, though.
904 Optional<SmallVector<int64_t, 8>>
905 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
906   assert(point.size() == getNumVars() - getNumLocalVars() &&
907          "Point should contain all vars except locals!");
908   assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() &&
909          "This function depends on locals being stored last!");
910   IntegerRelation copy = *this;
911   copy.setAndEliminate(0, point);
912   return copy.findIntegerSample();
913 }
914 
915 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
916   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalVars());
917   SmallVector<unsigned, 4> denominators(getNumLocalVars());
918   getLocalReprs(dividends, denominators, repr);
919 }
920 
921 void IntegerRelation::getLocalReprs(
922     std::vector<SmallVector<int64_t, 8>> &dividends,
923     SmallVector<unsigned, 4> &denominators) const {
924   std::vector<MaybeLocalRepr> repr(getNumLocalVars());
925   getLocalReprs(dividends, denominators, repr);
926 }
927 
928 void IntegerRelation::getLocalReprs(
929     std::vector<SmallVector<int64_t, 8>> &dividends,
930     SmallVector<unsigned, 4> &denominators,
931     std::vector<MaybeLocalRepr> &repr) const {
932 
933   repr.resize(getNumLocalVars());
934   dividends.resize(getNumLocalVars());
935   denominators.resize(getNumLocalVars());
936 
937   SmallVector<bool, 8> foundRepr(getNumVars(), false);
938   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
939     foundRepr[i] = true;
940 
941   unsigned divOffset = getNumDimAndSymbolVars();
942   bool changed;
943   do {
944     // Each time changed is true, at end of this iteration, one or more local
945     // vars have been detected as floor divs.
946     changed = false;
947     for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) {
948       if (!foundRepr[i + divOffset]) {
949         MaybeLocalRepr res = computeSingleVarRepr(
950             *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
951         if (!res)
952           continue;
953         foundRepr[i + divOffset] = true;
954         repr[i] = res;
955         changed = true;
956       }
957     }
958   } while (changed);
959 
960   // Set 0 denominator for variables for which no division representation
961   // could be found.
962   for (unsigned i = 0, e = repr.size(); i < e; ++i)
963     if (!repr[i])
964       denominators[i] = 0;
965 }
966 
967 /// Tightens inequalities given that we are dealing with integer spaces. This is
968 /// analogous to the GCD test but applied to inequalities. The constant term can
969 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
970 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
971 /// fast method - linear in the number of coefficients.
972 // Example on how this affects practical cases: consider the scenario:
973 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
974 // j >= 100 instead of the tighter (exact) j >= 128.
975 void IntegerRelation::gcdTightenInequalities() {
976   unsigned numCols = getNumCols();
977   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
978     // Normalize the constraint and tighten the constant term by the GCD.
979     int64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
980     if (gcd > 1)
981       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
982   }
983 }
984 
985 // Eliminates all variable variables in column range [posStart, posLimit).
986 // Returns the number of variables eliminated.
987 unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
988                                                 unsigned posLimit) {
989   // Return if variable positions to eliminate are out of range.
990   assert(posLimit <= getNumVars());
991   assert(hasConsistentState());
992 
993   if (posStart >= posLimit)
994     return 0;
995 
996   gcdTightenInequalities();
997 
998   unsigned pivotCol = 0;
999   for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1000     // Find a row which has a non-zero coefficient in column 'j'.
1001     unsigned pivotRow;
1002     if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
1003       // No pivot row in equalities with non-zero at 'pivotCol'.
1004       if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
1005         // If inequalities are also non-zero in 'pivotCol', it can be
1006         // eliminated.
1007         continue;
1008       }
1009       break;
1010     }
1011 
1012     // Eliminate variable at 'pivotCol' from each equality row.
1013     for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1014       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1015                               /*isEq=*/true);
1016       equalities.normalizeRow(i);
1017     }
1018 
1019     // Eliminate variable at 'pivotCol' from each inequality row.
1020     for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1021       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1022                               /*isEq=*/false);
1023       inequalities.normalizeRow(i);
1024     }
1025     removeEquality(pivotRow);
1026     gcdTightenInequalities();
1027   }
1028   // Update position limit based on number eliminated.
1029   posLimit = pivotCol;
1030   // Remove eliminated columns from all constraints.
1031   removeVarRange(posStart, posLimit);
1032   return posLimit - posStart;
1033 }
1034 
1035 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1036 // to check if a constraint is redundant.
1037 void IntegerRelation::removeRedundantInequalities() {
1038   SmallVector<bool, 32> redun(getNumInequalities(), false);
1039   // To check if an inequality is redundant, we replace the inequality by its
1040   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1041   // system is empty. If it is, the inequality is redundant.
1042   IntegerRelation tmpCst(*this);
1043   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1044     // Change the inequality to its complement.
1045     tmpCst.inequalities.negateRow(r);
1046     --tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
1047     if (tmpCst.isEmpty()) {
1048       redun[r] = true;
1049       // Zero fill the redundant inequality.
1050       inequalities.fillRow(r, /*value=*/0);
1051       tmpCst.inequalities.fillRow(r, /*value=*/0);
1052     } else {
1053       // Reverse the change (to avoid recreating tmpCst each time).
1054       ++tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
1055       tmpCst.inequalities.negateRow(r);
1056     }
1057   }
1058 
1059   unsigned pos = 0;
1060   for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
1061     if (!redun[r])
1062       inequalities.copyRow(r, pos++);
1063   }
1064   inequalities.resizeVertically(pos);
1065 }
1066 
1067 // A more complex check to eliminate redundant inequalities and equalities. Uses
1068 // Simplex to check if a constraint is redundant.
1069 void IntegerRelation::removeRedundantConstraints() {
1070   // First, we run gcdTightenInequalities. This allows us to catch some
1071   // constraints which are not redundant when considering rational solutions
1072   // but are redundant in terms of integer solutions.
1073   gcdTightenInequalities();
1074   Simplex simplex(*this);
1075   simplex.detectRedundant();
1076 
1077   unsigned pos = 0;
1078   unsigned numIneqs = getNumInequalities();
1079   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1080   // the first constraints added are the inequalities.
1081   for (unsigned r = 0; r < numIneqs; r++) {
1082     if (!simplex.isMarkedRedundant(r))
1083       inequalities.copyRow(r, pos++);
1084   }
1085   inequalities.resizeVertically(pos);
1086 
1087   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1088   // after the inequalities, a pair of constraints for each equality is added.
1089   // An equality is redundant if both the inequalities in its pair are
1090   // redundant.
1091   pos = 0;
1092   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1093     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1094           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1095       equalities.copyRow(r, pos++);
1096   }
1097   equalities.resizeVertically(pos);
1098 }
1099 
1100 Optional<uint64_t> IntegerRelation::computeVolume() const {
1101   assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
1102 
1103   Simplex simplex(*this);
1104   // If the polytope is rationally empty, there are certainly no integer
1105   // points.
1106   if (simplex.isEmpty())
1107     return 0;
1108 
1109   // Just find the maximum and minimum integer value of each non-local var
1110   // separately, thus finding the number of integer values each such var can
1111   // take. Multiplying these together gives a valid overapproximation of the
1112   // number of integer points in the relation. The result this gives is
1113   // equivalent to projecting (rationally) the relation onto its non-local vars
1114   // and returning the number of integer points in a minimal axis-parallel
1115   // hyperrectangular overapproximation of that.
1116   //
1117   // We also handle the special case where one dimension is unbounded and
1118   // another dimension can take no integer values. In this case, the volume is
1119   // zero.
1120   //
1121   // If there is no such empty dimension, if any dimension is unbounded we
1122   // just return the result as unbounded.
1123   uint64_t count = 1;
1124   SmallVector<int64_t, 8> dim(getNumVars() + 1);
1125   bool hasUnboundedVar = false;
1126   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) {
1127     dim[i] = 1;
1128     MaybeOptimum<int64_t> min, max;
1129     std::tie(min, max) = simplex.computeIntegerBounds(dim);
1130     dim[i] = 0;
1131 
1132     assert((!min.isEmpty() && !max.isEmpty()) &&
1133            "Polytope should be rationally non-empty!");
1134 
1135     // One of the dimensions is unbounded. Note this fact. We will return
1136     // unbounded if none of the other dimensions makes the volume zero.
1137     if (min.isUnbounded() || max.isUnbounded()) {
1138       hasUnboundedVar = true;
1139       continue;
1140     }
1141 
1142     // In this case there are no valid integer points and the volume is
1143     // definitely zero.
1144     if (min.getBoundedOptimum() > max.getBoundedOptimum())
1145       return 0;
1146 
1147     count *= (*max - *min + 1);
1148   }
1149 
1150   if (count == 0)
1151     return 0;
1152   if (hasUnboundedVar)
1153     return {};
1154   return count;
1155 }
1156 
1157 void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
1158   assert(posA < getNumLocalVars() && "Invalid local var position");
1159   assert(posB < getNumLocalVars() && "Invalid local var position");
1160 
1161   unsigned localOffset = getVarKindOffset(VarKind::Local);
1162   posA += localOffset;
1163   posB += localOffset;
1164   inequalities.addToColumn(posB, posA, 1);
1165   equalities.addToColumn(posB, posA, 1);
1166   removeVar(posB);
1167 }
1168 
1169 /// Adds additional local ids to the sets such that they both have the union
1170 /// of the local ids in each set, without changing the set of points that
1171 /// lie in `this` and `other`.
1172 ///
1173 /// To detect local ids that always take the same value, each local id is
1174 /// represented as a floordiv with constant denominator in terms of other ids.
1175 /// After extracting these divisions, local ids in `other` with the same
1176 /// division representation as some other local id in any set are considered
1177 /// duplicate and are merged.
1178 ///
1179 /// It is possible that division representation for some local id cannot be
1180 /// obtained, and thus these local ids are not considered for detecting
1181 /// duplicates.
1182 unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) {
1183   IntegerRelation &relA = *this;
1184   IntegerRelation &relB = other;
1185 
1186   unsigned oldALocals = relA.getNumLocalVars();
1187 
1188   // Merge function that merges the local variables in both sets by treating
1189   // them as the same variable.
1190   auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
1191     // We only merge from local at pos j to local at pos i, where j > i.
1192     if (i >= j)
1193       return false;
1194 
1195     // If i < oldALocals, we are trying to merge duplicate divs. Since we do not
1196     // want to merge duplicates in A, we ignore this call.
1197     if (j < oldALocals)
1198       return false;
1199 
1200     // Merge local at pos j into local at position i.
1201     relA.eliminateRedundantLocalVar(i, j);
1202     relB.eliminateRedundantLocalVar(i, j);
1203     return true;
1204   };
1205 
1206   presburger::mergeLocalVars(*this, other, merge);
1207 
1208   // Since we do not remove duplicate divisions in relA, this is guranteed to be
1209   // non-negative.
1210   return relA.getNumLocalVars() - oldALocals;
1211 }
1212 
1213 bool IntegerRelation::hasOnlyDivLocals() const {
1214   std::vector<MaybeLocalRepr> reprs;
1215   getLocalReprs(reprs);
1216   return llvm::all_of(reprs,
1217                       [](const MaybeLocalRepr &repr) { return bool(repr); });
1218 }
1219 
1220 void IntegerRelation::removeDuplicateDivs() {
1221   std::vector<SmallVector<int64_t, 8>> divs;
1222   SmallVector<unsigned, 4> denoms;
1223 
1224   getLocalReprs(divs, denoms);
1225   auto merge = [this](unsigned i, unsigned j) -> bool {
1226     eliminateRedundantLocalVar(i, j);
1227     return true;
1228   };
1229   presburger::removeDuplicateDivs(divs, denoms,
1230                                   getVarKindOffset(VarKind::Local), merge);
1231 }
1232 
1233 /// Removes local variables using equalities. Each equality is checked if it
1234 /// can be reduced to the form: `e = affine-expr`, where `e` is a local
1235 /// variable and `affine-expr` is an affine expression not containing `e`.
1236 /// If an equality satisfies this form, the local variable is replaced in
1237 /// each constraint and then removed. The equality used to replace this local
1238 /// variable is also removed.
1239 void IntegerRelation::removeRedundantLocalVars() {
1240   // Normalize the equality constraints to reduce coefficients of local
1241   // variables to 1 wherever possible.
1242   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1243     equalities.normalizeRow(i);
1244 
1245   while (true) {
1246     unsigned i, e, j, f;
1247     for (i = 0, e = getNumEqualities(); i < e; ++i) {
1248       // Find a local variable to eliminate using ith equality.
1249       for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j)
1250         if (std::abs(atEq(i, j)) == 1)
1251           break;
1252 
1253       // Local variable can be eliminated using ith equality.
1254       if (j < f)
1255         break;
1256     }
1257 
1258     // No equality can be used to eliminate a local variable.
1259     if (i == e)
1260       break;
1261 
1262     // Use the ith equality to simplify other equalities. If any changes
1263     // are made to an equality constraint, it is normalized by GCD.
1264     for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
1265       if (atEq(k, j) != 0) {
1266         eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
1267         equalities.normalizeRow(k);
1268       }
1269     }
1270 
1271     // Use the ith equality to simplify inequalities.
1272     for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
1273       eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
1274 
1275     // Remove the ith equality and the found local variable.
1276     removeVar(j);
1277     removeEquality(i);
1278   }
1279 }
1280 
1281 void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
1282                                      unsigned varLimit, VarKind dstKind,
1283                                      unsigned pos) {
1284   assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range");
1285 
1286   if (varStart >= varLimit)
1287     return;
1288 
1289   // Append new local variables corresponding to the dimensions to be converted.
1290   unsigned convertCount = varLimit - varStart;
1291   unsigned newVarsBegin = insertVar(dstKind, pos, convertCount);
1292 
1293   // Swap the new local variables with dimensions.
1294   //
1295   // Essentially, this moves the information corresponding to the specified ids
1296   // of kind `srcKind` to the `convertCount` newly created ids of kind
1297   // `dstKind`. In particular, this moves the columns in the constraint
1298   // matrices, and zeros out the initially occupied columns (because the newly
1299   // created ids we're swapping with were zero-initialized).
1300   unsigned offset = getVarKindOffset(srcKind);
1301   for (unsigned i = 0; i < convertCount; ++i)
1302     swapVar(offset + varStart + i, newVarsBegin + i);
1303 
1304   // Complete the move by deleting the initially occupied columns.
1305   removeVarRange(srcKind, varStart, varLimit);
1306 }
1307 
1308 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
1309   assert(pos < getNumCols());
1310   if (type == BoundType::EQ) {
1311     unsigned row = equalities.appendExtraRow();
1312     equalities(row, pos) = 1;
1313     equalities(row, getNumCols() - 1) = -value;
1314   } else {
1315     unsigned row = inequalities.appendExtraRow();
1316     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
1317     inequalities(row, getNumCols() - 1) =
1318         type == BoundType::LB ? -value : value;
1319   }
1320 }
1321 
1322 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
1323                                int64_t value) {
1324   assert(type != BoundType::EQ && "EQ not implemented");
1325   assert(expr.size() == getNumCols());
1326   unsigned row = inequalities.appendExtraRow();
1327   for (unsigned i = 0, e = expr.size(); i < e; ++i)
1328     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
1329   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
1330       type == BoundType::LB ? -value : value;
1331 }
1332 
1333 /// Adds a new local variable as the floordiv of an affine function of other
1334 /// variables, the coefficients of which are provided in 'dividend' and with
1335 /// respect to a positive constant 'divisor'. Two constraints are added to the
1336 /// system to capture equivalence with the floordiv.
1337 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
1338 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1339                                        int64_t divisor) {
1340   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1341   assert(divisor > 0 && "positive divisor expected");
1342 
1343   appendVar(VarKind::Local);
1344 
1345   SmallVector<int64_t, 8> dividendCopy(dividend.begin(), dividend.end());
1346   dividendCopy.insert(dividendCopy.end() - 1, 0);
1347   addInequality(
1348       getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
1349   addInequality(
1350       getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
1351 }
1352 
1353 /// Finds an equality that equates the specified variable to a constant.
1354 /// Returns the position of the equality row. If 'symbolic' is set to true,
1355 /// symbols are also treated like a constant, i.e., an affine function of the
1356 /// symbols is also treated like a constant. Returns -1 if such an equality
1357 /// could not be found.
1358 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
1359                                   bool symbolic = false) {
1360   assert(pos < cst.getNumVars() && "invalid position");
1361   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1362     int64_t v = cst.atEq(r, pos);
1363     if (v * v != 1)
1364       continue;
1365     unsigned c;
1366     unsigned f = symbolic ? cst.getNumDimVars() : cst.getNumVars();
1367     // This checks for zeros in all positions other than 'pos' in [0, f)
1368     for (c = 0; c < f; c++) {
1369       if (c == pos)
1370         continue;
1371       if (cst.atEq(r, c) != 0) {
1372         // Dependent on another variable.
1373         break;
1374       }
1375     }
1376     if (c == f)
1377       // Equality is free of other variables.
1378       return r;
1379   }
1380   return -1;
1381 }
1382 
1383 LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
1384   assert(pos < getNumVars() && "invalid position");
1385   int rowIdx;
1386   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
1387     return failure();
1388 
1389   // atEq(rowIdx, pos) is either -1 or 1.
1390   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
1391   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
1392   setAndEliminate(pos, constVal);
1393   return success();
1394 }
1395 
1396 void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
1397   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
1398     if (failed(constantFoldVar(t)))
1399       t++;
1400   }
1401 }
1402 
1403 /// Returns a non-negative constant bound on the extent (upper bound - lower
1404 /// bound) of the specified variable if it is found to be a constant; returns
1405 /// None if it's not a constant. This methods treats symbolic variables
1406 /// specially, i.e., it looks for constant differences between affine
1407 /// expressions involving only the symbolic variables. See comments at
1408 /// function definition for example. 'lb', if provided, is set to the lower
1409 /// bound associated with the constant difference. Note that 'lb' is purely
1410 /// symbolic and thus will contain the coefficients of the symbolic variables
1411 /// and the constant coefficient.
1412 //  Egs: 0 <= i <= 15, return 16.
1413 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
1414 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
1415 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
1416 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
1417 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
1418     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
1419     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
1420     unsigned *minUbPos) const {
1421   assert(pos < getNumDimVars() && "Invalid variable position");
1422 
1423   // Find an equality for 'pos'^th variable that equates it to some function
1424   // of the symbolic variables (+ constant).
1425   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
1426   if (eqPos != -1) {
1427     auto eq = getEquality(eqPos);
1428     // If the equality involves a local var, punt for now.
1429     // TODO: this can be handled in the future by using the explicit
1430     // representation of the local vars.
1431     if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
1432                      [](int64_t coeff) { return coeff == 0; }))
1433       return None;
1434 
1435     // This variable can only take a single value.
1436     if (lb) {
1437       // Set lb to that symbolic value.
1438       lb->resize(getNumSymbolVars() + 1);
1439       if (ub)
1440         ub->resize(getNumSymbolVars() + 1);
1441       for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
1442         int64_t v = atEq(eqPos, pos);
1443         // atEq(eqRow, pos) is either -1 or 1.
1444         assert(v * v == 1);
1445         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
1446                          : -atEq(eqPos, getNumDimVars() + c) / v;
1447         // Since this is an equality, ub = lb.
1448         if (ub)
1449           (*ub)[c] = (*lb)[c];
1450       }
1451       assert(boundFloorDivisor &&
1452              "both lb and divisor or none should be provided");
1453       *boundFloorDivisor = 1;
1454     }
1455     if (minLbPos)
1456       *minLbPos = eqPos;
1457     if (minUbPos)
1458       *minUbPos = eqPos;
1459     return 1;
1460   }
1461 
1462   // Check if the variable appears at all in any of the inequalities.
1463   unsigned r, e;
1464   for (r = 0, e = getNumInequalities(); r < e; r++) {
1465     if (atIneq(r, pos) != 0)
1466       break;
1467   }
1468   if (r == e)
1469     // If it doesn't, there isn't a bound on it.
1470     return None;
1471 
1472   // Positions of constraints that are lower/upper bounds on the variable.
1473   SmallVector<unsigned, 4> lbIndices, ubIndices;
1474 
1475   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
1476   // the bounds can only involve symbolic (and local) variables. Since the
1477   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1478   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1479   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
1480                                /*eqIndices=*/nullptr, /*offset=*/0,
1481                                /*num=*/getNumDimVars());
1482 
1483   Optional<int64_t> minDiff = None;
1484   unsigned minLbPosition = 0, minUbPosition = 0;
1485   for (auto ubPos : ubIndices) {
1486     for (auto lbPos : lbIndices) {
1487       // Look for a lower bound and an upper bound that only differ by a
1488       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
1489       // For example, if ii is the pos^th variable, we are looking for
1490       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
1491       // minimum among all such constant differences is kept since that's the
1492       // constant bounding the extent of the pos^th variable.
1493       unsigned j, e;
1494       for (j = 0, e = getNumCols() - 1; j < e; j++)
1495         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
1496           break;
1497         }
1498       if (j < getNumCols() - 1)
1499         continue;
1500       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
1501                                  atIneq(lbPos, getNumCols() - 1) + 1,
1502                              atIneq(lbPos, pos));
1503       // This bound is non-negative by definition.
1504       diff = std::max<int64_t>(diff, 0);
1505       if (minDiff == None || diff < minDiff) {
1506         minDiff = diff;
1507         minLbPosition = lbPos;
1508         minUbPosition = ubPos;
1509       }
1510     }
1511   }
1512   if (lb && minDiff) {
1513     // Set lb to the symbolic lower bound.
1514     lb->resize(getNumSymbolVars() + 1);
1515     if (ub)
1516       ub->resize(getNumSymbolVars() + 1);
1517     // The lower bound is the ceildiv of the lb constraint over the coefficient
1518     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
1519     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
1520     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
1521     *boundFloorDivisor = atIneq(minLbPosition, pos);
1522     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1523     for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
1524       (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
1525     }
1526     if (ub) {
1527       for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
1528         (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
1529     }
1530     // The lower bound leads to a ceildiv while the upper bound is a floordiv
1531     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
1532     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
1533     // the constant term for the lower bound.
1534     (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
1535   }
1536   if (minLbPos)
1537     *minLbPos = minLbPosition;
1538   if (minUbPos)
1539     *minUbPos = minUbPosition;
1540   return minDiff;
1541 }
1542 
1543 template <bool isLower>
1544 Optional<int64_t>
1545 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
1546   assert(pos < getNumVars() && "invalid position");
1547   // Project to 'pos'.
1548   projectOut(0, pos);
1549   projectOut(1, getNumVars() - 1);
1550   // Check if there's an equality equating the '0'^th variable to a constant.
1551   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
1552   if (eqRowIdx != -1)
1553     // atEq(rowIdx, 0) is either -1 or 1.
1554     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
1555 
1556   // Check if the variable appears at all in any of the inequalities.
1557   unsigned r, e;
1558   for (r = 0, e = getNumInequalities(); r < e; r++) {
1559     if (atIneq(r, 0) != 0)
1560       break;
1561   }
1562   if (r == e)
1563     // If it doesn't, there isn't a bound on it.
1564     return None;
1565 
1566   Optional<int64_t> minOrMaxConst = None;
1567 
1568   // Take the max across all const lower bounds (or min across all constant
1569   // upper bounds).
1570   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1571     if (isLower) {
1572       if (atIneq(r, 0) <= 0)
1573         // Not a lower bound.
1574         continue;
1575     } else if (atIneq(r, 0) >= 0) {
1576       // Not an upper bound.
1577       continue;
1578     }
1579     unsigned c, f;
1580     for (c = 0, f = getNumCols() - 1; c < f; c++)
1581       if (c != 0 && atIneq(r, c) != 0)
1582         break;
1583     if (c < getNumCols() - 1)
1584       // Not a constant bound.
1585       continue;
1586 
1587     int64_t boundConst =
1588         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
1589                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
1590     if (isLower) {
1591       if (minOrMaxConst == None || boundConst > minOrMaxConst)
1592         minOrMaxConst = boundConst;
1593     } else {
1594       if (minOrMaxConst == None || boundConst < minOrMaxConst)
1595         minOrMaxConst = boundConst;
1596     }
1597   }
1598   return minOrMaxConst;
1599 }
1600 
1601 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
1602                                                     unsigned pos) const {
1603   if (type == BoundType::LB)
1604     return IntegerRelation(*this)
1605         .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
1606   if (type == BoundType::UB)
1607     return IntegerRelation(*this)
1608         .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1609 
1610   assert(type == BoundType::EQ && "expected EQ");
1611   Optional<int64_t> lb =
1612       IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
1613           pos);
1614   Optional<int64_t> ub =
1615       IntegerRelation(*this)
1616           .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1617   return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
1618 }
1619 
1620 // A simple (naive and conservative) check for hyper-rectangularity.
1621 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
1622   assert(pos < getNumCols() - 1);
1623   // Check for two non-zero coefficients in the range [pos, pos + sum).
1624   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1625     unsigned sum = 0;
1626     for (unsigned c = pos; c < pos + num; c++) {
1627       if (atIneq(r, c) != 0)
1628         sum++;
1629     }
1630     if (sum > 1)
1631       return false;
1632   }
1633   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1634     unsigned sum = 0;
1635     for (unsigned c = pos; c < pos + num; c++) {
1636       if (atEq(r, c) != 0)
1637         sum++;
1638     }
1639     if (sum > 1)
1640       return false;
1641   }
1642   return true;
1643 }
1644 
1645 /// Removes duplicate constraints, trivially true constraints, and constraints
1646 /// that can be detected as redundant as a result of differing only in their
1647 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
1648 /// considered trivially true.
1649 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
1650 //  remove duplicates in place.
1651 void IntegerRelation::removeTrivialRedundancy() {
1652   gcdTightenInequalities();
1653   normalizeConstraintsByGCD();
1654 
1655   // A map used to detect redundancy stemming from constraints that only differ
1656   // in their constant term. The value stored is <row position, const term>
1657   // for a given row.
1658   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
1659       rowsWithoutConstTerm;
1660   // To unique rows.
1661   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
1662 
1663   // Check if constraint is of the form <non-negative-constant> >= 0.
1664   auto isTriviallyValid = [&](unsigned r) -> bool {
1665     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
1666       if (atIneq(r, c) != 0)
1667         return false;
1668     }
1669     return atIneq(r, getNumCols() - 1) >= 0;
1670   };
1671 
1672   // Detect and mark redundant constraints.
1673   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
1674   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1675     int64_t *rowStart = &inequalities(r, 0);
1676     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
1677     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
1678       redunIneq[r] = true;
1679       continue;
1680     }
1681 
1682     // Among constraints that only differ in the constant term part, mark
1683     // everything other than the one with the smallest constant term redundant.
1684     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
1685     // former two are redundant).
1686     int64_t constTerm = atIneq(r, getNumCols() - 1);
1687     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
1688     const auto &ret =
1689         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
1690     if (!ret.second) {
1691       // Check if the other constraint has a higher constant term.
1692       auto &val = ret.first->second;
1693       if (val.second > constTerm) {
1694         // The stored row is redundant. Mark it so, and update with this one.
1695         redunIneq[val.first] = true;
1696         val = {r, constTerm};
1697       } else {
1698         // The one stored makes this one redundant.
1699         redunIneq[r] = true;
1700       }
1701     }
1702   }
1703 
1704   // Scan to get rid of all rows marked redundant, in-place.
1705   unsigned pos = 0;
1706   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1707     if (!redunIneq[r])
1708       inequalities.copyRow(r, pos++);
1709 
1710   inequalities.resizeVertically(pos);
1711 
1712   // TODO: consider doing this for equalities as well, but probably not worth
1713   // the savings.
1714 }
1715 
1716 #undef DEBUG_TYPE
1717 #define DEBUG_TYPE "fm"
1718 
1719 /// Eliminates variable at the specified position using Fourier-Motzkin
1720 /// variable elimination. This technique is exact for rational spaces but
1721 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
1722 /// to a projection operation yielding the (convex) set of integer points
1723 /// contained in the rational shadow of the set. An emptiness test that relies
1724 /// on this method will guarantee emptiness, i.e., it disproves the existence of
1725 /// a solution if it says it's empty.
1726 /// If a non-null isResultIntegerExact is passed, it is set to true if the
1727 /// result is also integer exact. If it's set to false, the obtained solution
1728 /// *may* not be exact, i.e., it may contain integer points that do not have an
1729 /// integer pre-image in the original set.
1730 ///
1731 /// Eg:
1732 /// j >= 0, j <= i + 1
1733 /// i >= 0, i <= N + 1
1734 /// Eliminating i yields,
1735 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
1736 ///
1737 /// If darkShadow = true, this method computes the dark shadow on elimination;
1738 /// the dark shadow is a convex integer subset of the exact integer shadow. A
1739 /// non-empty dark shadow proves the existence of an integer solution. The
1740 /// elimination in such a case could however be an under-approximation, and thus
1741 /// should not be used for scanning sets or used by itself for dependence
1742 /// checking.
1743 ///
1744 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
1745 ///            ^
1746 ///            |
1747 ///            | * * * * o o
1748 ///         i  | * * o o o o
1749 ///            | o * * * * *
1750 ///            --------------->
1751 ///                 j ->
1752 ///
1753 /// Eliminating i from this system (projecting on the j dimension):
1754 /// rational shadow / integer light shadow:  1 <= j <= 6
1755 /// dark shadow:                             3 <= j <= 6
1756 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
1757 /// holes/splinters:                         j = 2
1758 ///
1759 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
1760 // TODO: a slight modification to yield dark shadow version of FM (tightened),
1761 // which can prove the existence of a solution if there is one.
1762 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
1763                                               bool *isResultIntegerExact) {
1764   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
1765   LLVM_DEBUG(dump());
1766   assert(pos < getNumVars() && "invalid position");
1767   assert(hasConsistentState());
1768 
1769   // Check if this variable can be eliminated through a substitution.
1770   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1771     if (atEq(r, pos) != 0) {
1772       // Use Gaussian elimination here (since we have an equality).
1773       LogicalResult ret = gaussianEliminateVar(pos);
1774       (void)ret;
1775       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
1776       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
1777       LLVM_DEBUG(dump());
1778       return;
1779     }
1780   }
1781 
1782   // A fast linear time tightening.
1783   gcdTightenInequalities();
1784 
1785   // Check if the variable appears at all in any of the inequalities.
1786   if (isColZero(pos)) {
1787     // If it doesn't appear, just remove the column and return.
1788     // TODO: refactor removeColumns to use it from here.
1789     removeVar(pos);
1790     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1791     LLVM_DEBUG(dump());
1792     return;
1793   }
1794 
1795   // Positions of constraints that are lower bounds on the variable.
1796   SmallVector<unsigned, 4> lbIndices;
1797   // Positions of constraints that are lower bounds on the variable.
1798   SmallVector<unsigned, 4> ubIndices;
1799   // Positions of constraints that do not involve the variable.
1800   std::vector<unsigned> nbIndices;
1801   nbIndices.reserve(getNumInequalities());
1802 
1803   // Gather all lower bounds and upper bounds of the variable. Since the
1804   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1805   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1806   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1807     if (atIneq(r, pos) == 0) {
1808       // Var does not appear in bound.
1809       nbIndices.push_back(r);
1810     } else if (atIneq(r, pos) >= 1) {
1811       // Lower bound.
1812       lbIndices.push_back(r);
1813     } else {
1814       // Upper bound.
1815       ubIndices.push_back(r);
1816     }
1817   }
1818 
1819   PresburgerSpace newSpace = getSpace();
1820   VarKind idKindRemove = newSpace.getVarKindAt(pos);
1821   unsigned relativePos = pos - newSpace.getVarKindOffset(idKindRemove);
1822   newSpace.removeVarRange(idKindRemove, relativePos, relativePos + 1);
1823 
1824   /// Create the new system which has one variable less.
1825   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
1826                          getNumEqualities(), getNumCols() - 1, newSpace);
1827 
1828   // This will be used to check if the elimination was integer exact.
1829   unsigned lcmProducts = 1;
1830 
1831   // Let x be the variable we are eliminating.
1832   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
1833   // that c_l, c_u >= 1) we have:
1834   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
1835   // We thus generate a constraint:
1836   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
1837   // Note if c_l = c_u = 1, all integer points captured by the resulting
1838   // constraint correspond to integer points in the original system (i.e., they
1839   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
1840   // integer exact.
1841   for (auto ubPos : ubIndices) {
1842     for (auto lbPos : lbIndices) {
1843       SmallVector<int64_t, 4> ineq;
1844       ineq.reserve(newRel.getNumCols());
1845       int64_t lbCoeff = atIneq(lbPos, pos);
1846       // Note that in the comments above, ubCoeff is the negation of the
1847       // coefficient in the canonical form as the view taken here is that of the
1848       // term being moved to the other size of '>='.
1849       int64_t ubCoeff = -atIneq(ubPos, pos);
1850       // TODO: refactor this loop to avoid all branches inside.
1851       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1852         if (l == pos)
1853           continue;
1854         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
1855         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
1856         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
1857                        atIneq(lbPos, l) * (lcm / lbCoeff));
1858         lcmProducts *= lcm;
1859       }
1860       if (darkShadow) {
1861         // The dark shadow is a convex subset of the exact integer shadow. If
1862         // there is a point here, it proves the existence of a solution.
1863         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
1864       }
1865       // TODO: we need to have a way to add inequalities in-place in
1866       // IntegerRelation instead of creating and copying over.
1867       newRel.addInequality(ineq);
1868     }
1869   }
1870 
1871   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
1872                           << "\n");
1873   if (lcmProducts == 1 && isResultIntegerExact)
1874     *isResultIntegerExact = true;
1875 
1876   // Copy over the constraints not involving this variable.
1877   for (auto nbPos : nbIndices) {
1878     SmallVector<int64_t, 4> ineq;
1879     ineq.reserve(getNumCols() - 1);
1880     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1881       if (l == pos)
1882         continue;
1883       ineq.push_back(atIneq(nbPos, l));
1884     }
1885     newRel.addInequality(ineq);
1886   }
1887 
1888   assert(newRel.getNumConstraints() ==
1889          lbIndices.size() * ubIndices.size() + nbIndices.size());
1890 
1891   // Copy over the equalities.
1892   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1893     SmallVector<int64_t, 4> eq;
1894     eq.reserve(newRel.getNumCols());
1895     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1896       if (l == pos)
1897         continue;
1898       eq.push_back(atEq(r, l));
1899     }
1900     newRel.addEquality(eq);
1901   }
1902 
1903   // GCD tightening and normalization allows detection of more trivially
1904   // redundant constraints.
1905   newRel.gcdTightenInequalities();
1906   newRel.normalizeConstraintsByGCD();
1907   newRel.removeTrivialRedundancy();
1908   clearAndCopyFrom(newRel);
1909   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1910   LLVM_DEBUG(dump());
1911 }
1912 
1913 #undef DEBUG_TYPE
1914 #define DEBUG_TYPE "presburger"
1915 
1916 void IntegerRelation::projectOut(unsigned pos, unsigned num) {
1917   if (num == 0)
1918     return;
1919 
1920   // 'pos' can be at most getNumCols() - 2 if num > 0.
1921   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
1922   assert(pos + num < getNumCols() && "invalid range");
1923 
1924   // Eliminate as many variables as possible using Gaussian elimination.
1925   unsigned currentPos = pos;
1926   unsigned numToEliminate = num;
1927   unsigned numGaussianEliminated = 0;
1928 
1929   while (currentPos < getNumVars()) {
1930     unsigned curNumEliminated =
1931         gaussianEliminateVars(currentPos, currentPos + numToEliminate);
1932     ++currentPos;
1933     numToEliminate -= curNumEliminated + 1;
1934     numGaussianEliminated += curNumEliminated;
1935   }
1936 
1937   // Eliminate the remaining using Fourier-Motzkin.
1938   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
1939     unsigned numToEliminate = num - numGaussianEliminated - i;
1940     fourierMotzkinEliminate(
1941         getBestVarToEliminate(*this, pos, pos + numToEliminate));
1942   }
1943 
1944   // Fast/trivial simplifications.
1945   gcdTightenInequalities();
1946   // Normalize constraints after tightening since the latter impacts this, but
1947   // not the other way round.
1948   normalizeConstraintsByGCD();
1949 }
1950 
1951 namespace {
1952 
1953 enum BoundCmpResult { Greater, Less, Equal, Unknown };
1954 
1955 /// Compares two affine bounds whose coefficients are provided in 'first' and
1956 /// 'second'. The last coefficient is the constant term.
1957 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1958   assert(a.size() == b.size());
1959 
1960   // For the bounds to be comparable, their corresponding variable
1961   // coefficients should be equal; the constant terms are then compared to
1962   // determine less/greater/equal.
1963 
1964   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
1965     return Unknown;
1966 
1967   if (a.back() == b.back())
1968     return Equal;
1969 
1970   return a.back() < b.back() ? Less : Greater;
1971 }
1972 } // namespace
1973 
1974 // Returns constraints that are common to both A & B.
1975 static void getCommonConstraints(const IntegerRelation &a,
1976                                  const IntegerRelation &b, IntegerRelation &c) {
1977   c = IntegerRelation(a.getSpace());
1978   // a naive O(n^2) check should be enough here given the input sizes.
1979   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
1980     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
1981       if (a.getInequality(r) == b.getInequality(s)) {
1982         c.addInequality(a.getInequality(r));
1983         break;
1984       }
1985     }
1986   }
1987   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
1988     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
1989       if (a.getEquality(r) == b.getEquality(s)) {
1990         c.addEquality(a.getEquality(r));
1991         break;
1992       }
1993     }
1994   }
1995 }
1996 
1997 // Computes the bounding box with respect to 'other' by finding the min of the
1998 // lower bounds and the max of the upper bounds along each of the dimensions.
1999 LogicalResult
2000 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
2001   assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
2002   assert(getNumLocalVars() == 0 && "local ids not supported yet here");
2003 
2004   // Get the constraints common to both systems; these will be added as is to
2005   // the union.
2006   IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
2007   getCommonConstraints(*this, otherCst, commonCst);
2008 
2009   std::vector<SmallVector<int64_t, 8>> boundingLbs;
2010   std::vector<SmallVector<int64_t, 8>> boundingUbs;
2011   boundingLbs.reserve(2 * getNumDimVars());
2012   boundingUbs.reserve(2 * getNumDimVars());
2013 
2014   // To hold lower and upper bounds for each dimension.
2015   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
2016   // To compute min of lower bounds and max of upper bounds for each dimension.
2017   SmallVector<int64_t, 4> minLb(getNumSymbolVars() + 1);
2018   SmallVector<int64_t, 4> maxUb(getNumSymbolVars() + 1);
2019   // To compute final new lower and upper bounds for the union.
2020   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
2021 
2022   int64_t lbFloorDivisor, otherLbFloorDivisor;
2023   for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
2024     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2025     if (!extent.hasValue())
2026       // TODO: symbolic extents when necessary.
2027       // TODO: handle union if a dimension is unbounded.
2028       return failure();
2029 
2030     auto otherExtent = otherCst.getConstantBoundOnDimSize(
2031         d, &otherLb, &otherLbFloorDivisor, &otherUb);
2032     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
2033       // TODO: symbolic extents when necessary.
2034       return failure();
2035 
2036     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2037 
2038     auto res = compareBounds(lb, otherLb);
2039     // Identify min.
2040     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2041       minLb = lb;
2042       // Since the divisor is for a floordiv, we need to convert to ceildiv,
2043       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2044       // div * i >= expr - div + 1.
2045       minLb.back() -= lbFloorDivisor - 1;
2046     } else if (res == BoundCmpResult::Greater) {
2047       minLb = otherLb;
2048       minLb.back() -= otherLbFloorDivisor - 1;
2049     } else {
2050       // Uncomparable - check for constant lower/upper bounds.
2051       auto constLb = getConstantBound(BoundType::LB, d);
2052       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
2053       if (!constLb.hasValue() || !constOtherLb.hasValue())
2054         return failure();
2055       std::fill(minLb.begin(), minLb.end(), 0);
2056       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
2057     }
2058 
2059     // Do the same for ub's but max of upper bounds. Identify max.
2060     auto uRes = compareBounds(ub, otherUb);
2061     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2062       maxUb = ub;
2063     } else if (uRes == BoundCmpResult::Less) {
2064       maxUb = otherUb;
2065     } else {
2066       // Uncomparable - check for constant lower/upper bounds.
2067       auto constUb = getConstantBound(BoundType::UB, d);
2068       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
2069       if (!constUb.hasValue() || !constOtherUb.hasValue())
2070         return failure();
2071       std::fill(maxUb.begin(), maxUb.end(), 0);
2072       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
2073     }
2074 
2075     std::fill(newLb.begin(), newLb.end(), 0);
2076     std::fill(newUb.begin(), newUb.end(), 0);
2077 
2078     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
2079     // and so it's the divisor for newLb and newUb as well.
2080     newLb[d] = lbFloorDivisor;
2081     newUb[d] = -lbFloorDivisor;
2082     // Copy over the symbolic part + constant term.
2083     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars());
2084     std::transform(newLb.begin() + getNumDimVars(), newLb.end(),
2085                    newLb.begin() + getNumDimVars(), std::negate<int64_t>());
2086     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars());
2087 
2088     boundingLbs.push_back(newLb);
2089     boundingUbs.push_back(newUb);
2090   }
2091 
2092   // Clear all constraints and add the lower/upper bounds for the bounding box.
2093   clearConstraints();
2094   for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
2095     addInequality(boundingLbs[d]);
2096     addInequality(boundingUbs[d]);
2097   }
2098 
2099   // Add the constraints that were common to both systems.
2100   append(commonCst);
2101   removeTrivialRedundancy();
2102 
2103   // TODO: copy over pure symbolic constraints from this and 'other' over to the
2104   // union (since the above are just the union along dimensions); we shouldn't
2105   // be discarding any other constraints on the symbols.
2106 
2107   return success();
2108 }
2109 
2110 bool IntegerRelation::isColZero(unsigned pos) const {
2111   unsigned rowPos;
2112   return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
2113          !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
2114 }
2115 
2116 /// Find positions of inequalities and equalities that do not have a coefficient
2117 /// for [pos, pos + num) variables.
2118 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
2119                                       unsigned num,
2120                                       SmallVectorImpl<unsigned> &nbIneqIndices,
2121                                       SmallVectorImpl<unsigned> &nbEqIndices) {
2122   assert(pos < cst.getNumVars() && "invalid start position");
2123   assert(pos + num <= cst.getNumVars() && "invalid limit");
2124 
2125   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2126     // The bounds are to be independent of [offset, offset + num) columns.
2127     unsigned c;
2128     for (c = pos; c < pos + num; ++c) {
2129       if (cst.atIneq(r, c) != 0)
2130         break;
2131     }
2132     if (c == pos + num)
2133       nbIneqIndices.push_back(r);
2134   }
2135 
2136   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2137     // The bounds are to be independent of [offset, offset + num) columns.
2138     unsigned c;
2139     for (c = pos; c < pos + num; ++c) {
2140       if (cst.atEq(r, c) != 0)
2141         break;
2142     }
2143     if (c == pos + num)
2144       nbEqIndices.push_back(r);
2145   }
2146 }
2147 
2148 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
2149   assert(pos + num <= getNumVars() && "invalid range");
2150 
2151   // Remove constraints that are independent of these variables.
2152   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2153   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2154 
2155   // Iterate in reverse so that indices don't have to be updated.
2156   // TODO: This method can be made more efficient (because removal of each
2157   // inequality leads to much shifting/copying in the underlying buffer).
2158   for (auto nbIndex : llvm::reverse(nbIneqIndices))
2159     removeInequality(nbIndex);
2160   for (auto nbIndex : llvm::reverse(nbEqIndices))
2161     removeEquality(nbIndex);
2162 }
2163 
2164 IntegerPolyhedron IntegerRelation::getDomainSet() const {
2165   IntegerRelation copyRel = *this;
2166 
2167   // Convert Range variables to Local variables.
2168   copyRel.convertVarKind(VarKind::Range, 0, getNumVarKind(VarKind::Range),
2169                          VarKind::Local);
2170 
2171   // Convert Domain variables to SetDim(Range) variables.
2172   copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain),
2173                          VarKind::SetDim);
2174 
2175   return IntegerPolyhedron(std::move(copyRel));
2176 }
2177 
2178 IntegerPolyhedron IntegerRelation::getRangeSet() const {
2179   IntegerRelation copyRel = *this;
2180 
2181   // Convert Domain variables to Local variables.
2182   copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain),
2183                          VarKind::Local);
2184 
2185   // We do not need to do anything to Range variables since they are already in
2186   // SetDim position.
2187 
2188   return IntegerPolyhedron(std::move(copyRel));
2189 }
2190 
2191 void IntegerRelation::intersectDomain(const IntegerPolyhedron &poly) {
2192   assert(getDomainSet().getSpace().isCompatible(poly.getSpace()) &&
2193          "Domain set is not compatible with poly");
2194 
2195   // Treating the poly as a relation, convert it from `0 -> R` to `R -> 0`.
2196   IntegerRelation rel = poly;
2197   rel.inverse();
2198 
2199   // Append dummy range variables to make the spaces compatible.
2200   rel.appendVar(VarKind::Range, getNumRangeVars());
2201 
2202   // Intersect in place.
2203   mergeLocalVars(rel);
2204   append(rel);
2205 }
2206 
2207 void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) {
2208   assert(getRangeSet().getSpace().isCompatible(poly.getSpace()) &&
2209          "Range set is not compatible with poly");
2210 
2211   IntegerRelation rel = poly;
2212 
2213   // Append dummy domain variables to make the spaces compatible.
2214   rel.appendVar(VarKind::Domain, getNumDomainVars());
2215 
2216   mergeLocalVars(rel);
2217   append(rel);
2218 }
2219 
2220 void IntegerRelation::inverse() {
2221   unsigned numRangeVars = getNumVarKind(VarKind::Range);
2222   convertVarKind(VarKind::Domain, 0, getVarKindEnd(VarKind::Domain),
2223                  VarKind::Range);
2224   convertVarKind(VarKind::Range, 0, numRangeVars, VarKind::Domain);
2225 }
2226 
2227 void IntegerRelation::compose(const IntegerRelation &rel) {
2228   assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) &&
2229          "Range of `this` should be compatible with Domain of `rel`");
2230 
2231   IntegerRelation copyRel = rel;
2232 
2233   // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C.
2234   // We convert R1 to A -> (B X C), and R2 to B X C then intersect the range of
2235   // R1 with R2. After this, we get R1: A -> C, by projecting out B.
2236   // TODO: Using nested spaces here would help, since we could directly
2237   // intersect the range with another relation.
2238   unsigned numBVars = getNumRangeVars();
2239 
2240   // Convert R1 from A -> B to A -> (B X C).
2241   appendVar(VarKind::Range, copyRel.getNumRangeVars());
2242 
2243   // Convert R2 to B X C.
2244   copyRel.convertVarKind(VarKind::Domain, 0, numBVars, VarKind::Range, 0);
2245 
2246   // Intersect R2 to range of R1.
2247   intersectRange(IntegerPolyhedron(copyRel));
2248 
2249   // Project out B in R1.
2250   convertVarKind(VarKind::Range, 0, numBVars, VarKind::Local);
2251 }
2252 
2253 void IntegerRelation::applyDomain(const IntegerRelation &rel) {
2254   inverse();
2255   compose(rel);
2256   inverse();
2257 }
2258 
2259 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
2260 
2261 void IntegerRelation::printSpace(raw_ostream &os) const {
2262   space.print(os);
2263   os << getNumConstraints() << " constraints\n";
2264 }
2265 
2266 void IntegerRelation::print(raw_ostream &os) const {
2267   assert(hasConsistentState());
2268   printSpace(os);
2269   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2270     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2271       os << atEq(i, j) << " ";
2272     }
2273     os << "= 0\n";
2274   }
2275   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2276     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2277       os << atIneq(i, j) << " ";
2278     }
2279     os << ">= 0\n";
2280   }
2281   os << '\n';
2282 }
2283 
2284 void IntegerRelation::dump() const { print(llvm::errs()); }
2285 
2286 unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
2287                                       unsigned num) {
2288   assert((kind != VarKind::Domain || num == 0) &&
2289          "Domain has to be zero in a set");
2290   return IntegerRelation::insertVar(kind, pos, num);
2291 }
2292 IntegerPolyhedron
2293 IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
2294   return IntegerPolyhedron(IntegerRelation::intersect(other));
2295 }
2296 
2297 PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
2298   return PresburgerSet(IntegerRelation::subtract(other));
2299 }
2300