1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A class to represent a relation over integer tuples. A relation is
10 // represented as a constraint system over a space of tuples of integer valued
11 // variables supporting symbolic variables and existential quantification.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/Presburger/IntegerRelation.h"
16 #include "mlir/Analysis/Presburger/LinearTransform.h"
17 #include "mlir/Analysis/Presburger/PWMAFunction.h"
18 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
19 #include "mlir/Analysis/Presburger/Simplex.h"
20 #include "mlir/Analysis/Presburger/Utils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "presburger"
26 
27 using namespace mlir;
28 using namespace presburger;
29 
30 using llvm::SmallDenseMap;
31 using llvm::SmallDenseSet;
32 
clone() const33 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
34   return std::make_unique<IntegerRelation>(*this);
35 }
36 
clone() const37 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
38   return std::make_unique<IntegerPolyhedron>(*this);
39 }
40 
setSpace(const PresburgerSpace & oSpace)41 void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
42   assert(space.getNumVars() == oSpace.getNumVars() && "invalid space!");
43   space = oSpace;
44 }
45 
setSpaceExceptLocals(const PresburgerSpace & oSpace)46 void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
47   assert(oSpace.getNumLocalVars() == 0 && "no locals should be present!");
48   assert(oSpace.getNumVars() <= getNumVars() && "invalid space!");
49   unsigned newNumLocals = getNumVars() - oSpace.getNumVars();
50   space = oSpace;
51   space.insertVar(VarKind::Local, 0, newNumLocals);
52 }
53 
append(const IntegerRelation & other)54 void IntegerRelation::append(const IntegerRelation &other) {
55   assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
56 
57   inequalities.reserveRows(inequalities.getNumRows() +
58                            other.getNumInequalities());
59   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
60 
61   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
62     addInequality(other.getInequality(r));
63   }
64   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
65     addEquality(other.getEquality(r));
66   }
67 }
68 
intersect(IntegerRelation other) const69 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const {
70   IntegerRelation result = *this;
71   result.mergeLocalVars(other);
72   result.append(other);
73   return result;
74 }
75 
isEqual(const IntegerRelation & other) const76 bool IntegerRelation::isEqual(const IntegerRelation &other) const {
77   assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
78   return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
79 }
80 
isSubsetOf(const IntegerRelation & other) const81 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
82   assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
83   return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
84 }
85 
86 MaybeOptimum<SmallVector<Fraction, 8>>
findRationalLexMin() const87 IntegerRelation::findRationalLexMin() const {
88   assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
89   MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
90       LexSimplex(*this).findRationalLexMin();
91 
92   if (!maybeLexMin.isBounded())
93     return maybeLexMin;
94 
95   // The Simplex returns the lexmin over all the variables including locals. But
96   // locals are not actually part of the space and should not be returned in the
97   // result. Since the locals are placed last in the list of variables, they
98   // will be minimized last in the lexmin. So simply truncating out the locals
99   // from the end of the answer gives the desired lexmin over the dimensions.
100   assert(maybeLexMin->size() == getNumVars() &&
101          "Incorrect number of vars in lexMin!");
102   maybeLexMin->resize(getNumDimAndSymbolVars());
103   return maybeLexMin;
104 }
105 
106 MaybeOptimum<SmallVector<int64_t, 8>>
findIntegerLexMin() const107 IntegerRelation::findIntegerLexMin() const {
108   assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
109   MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
110       LexSimplex(*this).findIntegerLexMin();
111 
112   if (!maybeLexMin.isBounded())
113     return maybeLexMin.getKind();
114 
115   // The Simplex returns the lexmin over all the variables including locals. But
116   // locals are not actually part of the space and should not be returned in the
117   // result. Since the locals are placed last in the list of variables, they
118   // will be minimized last in the lexmin. So simply truncating out the locals
119   // from the end of the answer gives the desired lexmin over the dimensions.
120   assert(maybeLexMin->size() == getNumVars() &&
121          "Incorrect number of vars in lexMin!");
122   maybeLexMin->resize(getNumDimAndSymbolVars());
123   return maybeLexMin;
124 }
125 
rangeIsZero(ArrayRef<int64_t> range)126 static bool rangeIsZero(ArrayRef<int64_t> range) {
127   return llvm::all_of(range, [](int64_t x) { return x == 0; });
128 }
129 
removeConstraintsInvolvingVarRange(IntegerRelation & poly,unsigned begin,unsigned count)130 static void removeConstraintsInvolvingVarRange(IntegerRelation &poly,
131                                                unsigned begin, unsigned count) {
132   // We loop until i > 0 and index into i - 1 to avoid sign issues.
133   //
134   // We iterate backwards so that whether we remove constraint i - 1 or not, the
135   // next constraint to be tested is always i - 2.
136   for (unsigned i = poly.getNumEqualities(); i > 0; i--)
137     if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count)))
138       poly.removeEquality(i - 1);
139   for (unsigned i = poly.getNumInequalities(); i > 0; i--)
140     if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count)))
141       poly.removeInequality(i - 1);
142 }
143 
getCounts() const144 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
145   return {getSpace(), getNumInequalities(), getNumEqualities()};
146 }
147 
truncateVarKind(VarKind kind,unsigned num)148 void IntegerRelation::truncateVarKind(VarKind kind, unsigned num) {
149   unsigned curNum = getNumVarKind(kind);
150   assert(num <= curNum && "Can't truncate to more vars!");
151   removeVarRange(kind, num, curNum);
152 }
153 
truncateVarKind(VarKind kind,const CountsSnapshot & counts)154 void IntegerRelation::truncateVarKind(VarKind kind,
155                                       const CountsSnapshot &counts) {
156   truncateVarKind(kind, counts.getSpace().getNumVarKind(kind));
157 }
158 
truncate(const CountsSnapshot & counts)159 void IntegerRelation::truncate(const CountsSnapshot &counts) {
160   truncateVarKind(VarKind::Domain, counts);
161   truncateVarKind(VarKind::Range, counts);
162   truncateVarKind(VarKind::Symbol, counts);
163   truncateVarKind(VarKind::Local, counts);
164   removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
165   removeEqualityRange(counts.getNumEqs(), getNumEqualities());
166 }
167 
/// Computes a `PresburgerRelation` representation of this relation in which
/// the locals without a division representation have been eliminated.
///
/// Outline:
///  1. If there are no locals at all, or every local has a division
///     representation, `*this` is returned wrapped in a PresburgerRelation.
///  2. Otherwise, on a copy, the locals without a division representation are
///     swapped to the end of the variable list.
///  3. A symbolic integer lexmin is computed over those trailing locals with
///     all other variables acting as symbols. The union of the lexmin's domain
///     and the unbounded domain is the set of values of the other variables
///     for which some assignment to the trailing locals exists.
///  4. The space of the result is reset to this relation's space with the
///     locals removed.
PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
  // If there are no locals, we're done.
  if (getNumLocalVars() == 0)
    return PresburgerRelation(*this);

  // Move all the non-div locals to the end, as the current API to
  // SymbolicLexMin requires these to form a contiguous range.
  //
  // Take a copy so we can perform mutations.
  IntegerRelation copy = *this;
  std::vector<MaybeLocalRepr> reprs(getNumLocalVars());
  copy.getLocalReprs(&reprs);

  // Iterate through all the locals. The last `numNonDivLocals` are the locals
  // that have been scanned already and do not have division representations.
  unsigned numNonDivLocals = 0;
  unsigned offset = copy.getVarKindOffset(VarKind::Local);
  // Note: `i` is only advanced when local `i` has a division representation.
  // After a swap, position `i` holds a not-yet-examined local, so it must be
  // looked at again; the loop bound shrinks instead.
  for (unsigned i = 0, e = copy.getNumLocalVars(); i < e - numNonDivLocals;) {
    if (!reprs[i]) {
      // Whenever we come across a local that does not have a division
      // representation, we swap it to the `numNonDivLocals`-th last position
      // and increment `numNonDivLocals`. `reprs` also needs to be swapped.
      copy.swapVar(offset + i, offset + e - numNonDivLocals - 1);
      std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
      ++numNonDivLocals;
      continue;
    }
    ++i;
  }

  // If there are no non-div locals, we're done.
  // The unmodified `*this` is returned here, not `copy`.
  if (numNonDivLocals == 0)
    return PresburgerRelation(*this);

  // We computeSymbolicIntegerLexMin by considering the non-div locals as
  // "non-symbols" and considering everything else as "symbols". This will
  // compute a function mapping assignments to "symbols" to the
  // lexicographically minimal valid assignment of "non-symbols", when a
  // satisfying assignment exists. It separately returns the set of assignments
  // to the "symbols" such that a satisfying assignment to the "non-symbols"
  // exists but the lexmin is unbounded. We basically want to find the set of
  // values of the "symbols" such that an assignment to the "non-symbols"
  // exists, which is the union of the domain of the returned lexmin function
  // and the returned set of assignments to the "symbols" that makes the lexmin
  // unbounded.
  SymbolicLexMin lexminResult =
      SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
                         IntegerPolyhedron(PresburgerSpace::getSetSpace(
                             /*numDims=*/copy.getNumVars() - numNonDivLocals)))
          .computeSymbolicIntegerLexMin();
  PresburgerRelation result =
      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);

  // The result set might lie in the wrong space -- all its ids are dims.
  // Set it to the desired space and return.
  // `space` below is a local copy that shadows the member of the same name;
  // all locals are stripped from the copy before it is applied to `result`.
  PresburgerSpace space = getSpace();
  space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
  result.setSpace(space);
  return result;
}
228 
findSymbolicIntegerLexMin() const229 SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
230   // Symbol and Domain vars will be used as symbols for symbolic lexmin.
231   // In other words, for every value of the symbols and domain, return the
232   // lexmin value of the (range, locals).
233   llvm::SmallBitVector isSymbol(getNumVars(), false);
234   isSymbol.set(getVarKindOffset(VarKind::Symbol),
235                getVarKindEnd(VarKind::Symbol));
236   isSymbol.set(getVarKindOffset(VarKind::Domain),
237                getVarKindEnd(VarKind::Domain));
238   // Compute the symbolic lexmin of the dims and locals, with the symbols being
239   // the actual symbols of this set.
240   SymbolicLexMin result =
241       SymbolicLexSimplex(*this,
242                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
243                              /*numDims=*/getNumDomainVars(),
244                              /*numSymbols=*/getNumSymbolVars())),
245                          isSymbol)
246           .computeSymbolicIntegerLexMin();
247 
248   // We want to return only the lexmin over the dims, so strip the locals from
249   // the computed lexmin.
250   result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
251                                getNumLocalVars());
252   return result;
253 }
254 
255 PresburgerRelation
subtract(const PresburgerRelation & set) const256 IntegerRelation::subtract(const PresburgerRelation &set) const {
257   return PresburgerRelation(*this).subtract(set);
258 }
259 
insertVar(VarKind kind,unsigned pos,unsigned num)260 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
261   assert(pos <= getNumVarKind(kind));
262 
263   unsigned insertPos = space.insertVar(kind, pos, num);
264   inequalities.insertColumns(insertPos, num);
265   equalities.insertColumns(insertPos, num);
266   return insertPos;
267 }
268 
appendVar(VarKind kind,unsigned num)269 unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) {
270   unsigned pos = getNumVarKind(kind);
271   return insertVar(kind, pos, num);
272 }
273 
addEquality(ArrayRef<int64_t> eq)274 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
275   assert(eq.size() == getNumCols());
276   unsigned row = equalities.appendExtraRow();
277   for (unsigned i = 0, e = eq.size(); i < e; ++i)
278     equalities(row, i) = eq[i];
279 }
280 
addInequality(ArrayRef<int64_t> inEq)281 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
282   assert(inEq.size() == getNumCols());
283   unsigned row = inequalities.appendExtraRow();
284   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
285     inequalities(row, i) = inEq[i];
286 }
287 
removeVar(VarKind kind,unsigned pos)288 void IntegerRelation::removeVar(VarKind kind, unsigned pos) {
289   removeVarRange(kind, pos, pos + 1);
290 }
291 
removeVar(unsigned pos)292 void IntegerRelation::removeVar(unsigned pos) { removeVarRange(pos, pos + 1); }
293 
removeVarRange(VarKind kind,unsigned varStart,unsigned varLimit)294 void IntegerRelation::removeVarRange(VarKind kind, unsigned varStart,
295                                      unsigned varLimit) {
296   assert(varLimit <= getNumVarKind(kind));
297 
298   if (varStart >= varLimit)
299     return;
300 
301   // Remove eliminated variables from the constraints.
302   unsigned offset = getVarKindOffset(kind);
303   equalities.removeColumns(offset + varStart, varLimit - varStart);
304   inequalities.removeColumns(offset + varStart, varLimit - varStart);
305 
306   // Remove eliminated variables from the space.
307   space.removeVarRange(kind, varStart, varLimit);
308 }
309 
removeVarRange(unsigned varStart,unsigned varLimit)310 void IntegerRelation::removeVarRange(unsigned varStart, unsigned varLimit) {
311   assert(varLimit <= getNumVars());
312 
313   if (varStart >= varLimit)
314     return;
315 
316   // Helper function to remove vars of the specified kind in the given range
317   // [start, limit), The range is absolute (i.e. it is not relative to the kind
318   // of variable). Also updates `limit` to reflect the deleted variables.
319   auto removeVarKindInRange = [this](VarKind kind, unsigned &start,
320                                      unsigned &limit) {
321     if (start >= limit)
322       return;
323 
324     unsigned offset = getVarKindOffset(kind);
325     unsigned num = getNumVarKind(kind);
326 
327     // Get `start`, `limit` relative to the specified kind.
328     unsigned relativeStart =
329         start <= offset ? 0 : std::min(num, start - offset);
330     unsigned relativeLimit =
331         limit <= offset ? 0 : std::min(num, limit - offset);
332 
333     // Remove vars of the specified kind in the relative range.
334     removeVarRange(kind, relativeStart, relativeLimit);
335 
336     // Update `limit` to reflect deleted variables.
337     // `start` does not need to be updated because any variables that are
338     // deleted are after position `start`.
339     limit -= relativeLimit - relativeStart;
340   };
341 
342   removeVarKindInRange(VarKind::Domain, varStart, varLimit);
343   removeVarKindInRange(VarKind::Range, varStart, varLimit);
344   removeVarKindInRange(VarKind::Symbol, varStart, varLimit);
345   removeVarKindInRange(VarKind::Local, varStart, varLimit);
346 }
347 
removeEquality(unsigned pos)348 void IntegerRelation::removeEquality(unsigned pos) {
349   equalities.removeRow(pos);
350 }
351 
removeInequality(unsigned pos)352 void IntegerRelation::removeInequality(unsigned pos) {
353   inequalities.removeRow(pos);
354 }
355 
removeEqualityRange(unsigned start,unsigned end)356 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
357   if (start >= end)
358     return;
359   equalities.removeRows(start, end - start);
360 }
361 
removeInequalityRange(unsigned start,unsigned end)362 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
363   if (start >= end)
364     return;
365   inequalities.removeRows(start, end - start);
366 }
367 
swapVar(unsigned posA,unsigned posB)368 void IntegerRelation::swapVar(unsigned posA, unsigned posB) {
369   assert(posA < getNumVars() && "invalid position A");
370   assert(posB < getNumVars() && "invalid position B");
371 
372   if (posA == posB)
373     return;
374 
375   inequalities.swapColumns(posA, posB);
376   equalities.swapColumns(posA, posB);
377 }
378 
clearConstraints()379 void IntegerRelation::clearConstraints() {
380   equalities.resizeVertically(0);
381   inequalities.resizeVertically(0);
382 }
383 
384 /// Gather all lower and upper bounds of the variable at `pos`, and
385 /// optionally any equalities on it. In addition, the bounds are to be
386 /// independent of variables in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const387 void IntegerRelation::getLowerAndUpperBoundIndices(
388     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
389     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
390     unsigned offset, unsigned num) const {
391   assert(pos < getNumVars() && "invalid position");
392   assert(offset + num < getNumCols() && "invalid range");
393 
394   // Checks for a constraint that has a non-zero coeff for the variables in
395   // the position range [offset, offset + num) while ignoring `pos`.
396   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
397     unsigned c, f;
398     auto cst = isEq ? getEquality(r) : getInequality(r);
399     for (c = offset, f = offset + num; c < f; ++c) {
400       if (c == pos)
401         continue;
402       if (cst[c] != 0)
403         break;
404     }
405     return c < f;
406   };
407 
408   // Gather all lower bounds and upper bounds of the variable. Since the
409   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
410   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
411   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
412     // The bounds are to be independent of [offset, offset + num) columns.
413     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
414       continue;
415     if (atIneq(r, pos) >= 1) {
416       // Lower bound.
417       lbIndices->push_back(r);
418     } else if (atIneq(r, pos) <= -1) {
419       // Upper bound.
420       ubIndices->push_back(r);
421     }
422   }
423 
424   // An equality is both a lower and upper bound. Record any equalities
425   // involving the pos^th variable.
426   if (!eqIndices)
427     return;
428 
429   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
430     if (atEq(r, pos) == 0)
431       continue;
432     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
433       continue;
434     eqIndices->push_back(r);
435   }
436 }
437 
hasConsistentState() const438 bool IntegerRelation::hasConsistentState() const {
439   if (!inequalities.hasConsistentState())
440     return false;
441   if (!equalities.hasConsistentState())
442     return false;
443   return true;
444 }
445 
setAndEliminate(unsigned pos,ArrayRef<int64_t> values)446 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
447   if (values.empty())
448     return;
449   assert(pos + values.size() <= getNumVars() &&
450          "invalid position or too many values");
451   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
452   // constant term and removing the var x_j. We do this for all the vars
453   // pos, pos + 1, ... pos + values.size() - 1.
454   unsigned constantColPos = getNumCols() - 1;
455   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
456     inequalities.addToColumn(i + pos, constantColPos, values[i]);
457   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
458     equalities.addToColumn(i + pos, constantColPos, values[i]);
459   removeVarRange(pos, pos + values.size());
460 }
461 
clearAndCopyFrom(const IntegerRelation & other)462 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
463   *this = other;
464 }
465 
466 // Searches for a constraint with a non-zero coefficient at `colIdx` in
467 // equality (isEq=true) or inequality (isEq=false) constraints.
468 // Returns true and sets row found in search in `rowIdx`, false otherwise.
findConstraintWithNonZeroAt(unsigned colIdx,bool isEq,unsigned * rowIdx) const469 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
470                                                   unsigned *rowIdx) const {
471   assert(colIdx < getNumCols() && "position out of bounds");
472   auto at = [&](unsigned rowIdx) -> int64_t {
473     return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
474   };
475   unsigned e = isEq ? getNumEqualities() : getNumInequalities();
476   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
477     if (at(*rowIdx) != 0) {
478       return true;
479     }
480   }
481   return false;
482 }
483 
normalizeConstraintsByGCD()484 void IntegerRelation::normalizeConstraintsByGCD() {
485   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
486     equalities.normalizeRow(i);
487   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
488     inequalities.normalizeRow(i);
489 }
490 
hasInvalidConstraint() const491 bool IntegerRelation::hasInvalidConstraint() const {
492   assert(hasConsistentState());
493   auto check = [&](bool isEq) -> bool {
494     unsigned numCols = getNumCols();
495     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
496     for (unsigned i = 0, e = numRows; i < e; ++i) {
497       unsigned j;
498       for (j = 0; j < numCols - 1; ++j) {
499         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
500         // Skip rows with non-zero variable coefficients.
501         if (v != 0)
502           break;
503       }
504       if (j < numCols - 1) {
505         continue;
506       }
507       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
508       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
509       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
510       if ((isEq && v != 0) || (!isEq && v < 0)) {
511         return true;
512       }
513     }
514     return false;
515   };
516   if (check(/*isEq=*/true))
517     return true;
518   return check(/*isEq=*/false);
519 }
520 
521 /// Eliminate variable from constraint at `rowIdx` based on coefficient at
522 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
523 /// updated as they have already been eliminated.
eliminateFromConstraint(IntegerRelation * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)524 static void eliminateFromConstraint(IntegerRelation *constraints,
525                                     unsigned rowIdx, unsigned pivotRow,
526                                     unsigned pivotCol, unsigned elimColStart,
527                                     bool isEq) {
528   // Skip if equality 'rowIdx' if same as 'pivotRow'.
529   if (isEq && rowIdx == pivotRow)
530     return;
531   auto at = [&](unsigned i, unsigned j) -> int64_t {
532     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
533   };
534   int64_t leadCoeff = at(rowIdx, pivotCol);
535   // Skip if leading coefficient at 'rowIdx' is already zero.
536   if (leadCoeff == 0)
537     return;
538   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
539   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
540   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
541   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
542   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
543 
544   unsigned numCols = constraints->getNumCols();
545   for (unsigned j = 0; j < numCols; ++j) {
546     // Skip updating column 'j' if it was just eliminated.
547     if (j >= elimColStart && j < pivotCol)
548       continue;
549     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
550                 rowMultiplier * at(rowIdx, j);
551     isEq ? constraints->atEq(rowIdx, j) = v
552          : constraints->atIneq(rowIdx, j) = v;
553   }
554 }
555 
556 /// Returns the position of the variable that has the minimum <number of lower
557 /// bounds> times <number of upper bounds> from the specified range of
558 /// variables [start, end). It is often best to eliminate in the increasing
559 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
560 /// that many new constraints.
getBestVarToEliminate(const IntegerRelation & cst,unsigned start,unsigned end)561 static unsigned getBestVarToEliminate(const IntegerRelation &cst,
562                                       unsigned start, unsigned end) {
563   assert(start < cst.getNumVars() && end < cst.getNumVars() + 1);
564 
565   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
566     unsigned numLb = 0;
567     unsigned numUb = 0;
568     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
569       if (cst.atIneq(r, pos) > 0) {
570         ++numLb;
571       } else if (cst.atIneq(r, pos) < 0) {
572         ++numUb;
573       }
574     }
575     return numLb * numUb;
576   };
577 
578   unsigned minLoc = start;
579   unsigned min = getProductOfNumLowerUpperBounds(start);
580   for (unsigned c = start + 1; c < end; c++) {
581     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
582     if (numLbUbProduct < min) {
583       min = numLbUbProduct;
584       minLoc = c;
585     }
586   }
587   return minLoc;
588 }
589 
590 // Checks for emptiness of the set by eliminating variables successively and
591 // using the GCD test (on all equality constraints) and checking for trivially
592 // invalid constraints. Returns 'true' if the constraint system is found to be
593 // empty; false otherwise.
isEmpty() const594 bool IntegerRelation::isEmpty() const {
595   if (isEmptyByGCDTest() || hasInvalidConstraint())
596     return true;
597 
598   IntegerRelation tmpCst(*this);
599 
600   // First, eliminate as many local variables as possible using equalities.
601   tmpCst.removeRedundantLocalVars();
602   if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
603     return true;
604 
605   // Eliminate as many variables as possible using Gaussian elimination.
606   unsigned currentPos = 0;
607   while (currentPos < tmpCst.getNumVars()) {
608     tmpCst.gaussianEliminateVars(currentPos, tmpCst.getNumVars());
609     ++currentPos;
610     // We check emptiness through trivial checks after eliminating each ID to
611     // detect emptiness early. Since the checks isEmptyByGCDTest() and
612     // hasInvalidConstraint() are linear time and single sweep on the constraint
613     // buffer, this appears reasonable - but can optimize in the future.
614     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
615       return true;
616   }
617 
618   // Eliminate the remaining using FM.
619   for (unsigned i = 0, e = tmpCst.getNumVars(); i < e; i++) {
620     tmpCst.fourierMotzkinEliminate(
621         getBestVarToEliminate(tmpCst, 0, tmpCst.getNumVars()));
622     // Check for a constraint explosion. This rarely happens in practice, but
623     // this check exists as a safeguard against improperly constructed
624     // constraint systems or artificially created arbitrarily complex systems
625     // that aren't the intended use case for IntegerRelation. This is
626     // needed since FM has a worst case exponential complexity in theory.
627     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumVars()) {
628       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
629       return false;
630     }
631 
632     // FM wouldn't have modified the equalities in any way. So no need to again
633     // run GCD test. Check for trivial invalid constraints.
634     if (tmpCst.hasInvalidConstraint())
635       return true;
636   }
637   return false;
638 }
639 
640 // Runs the GCD test on all equality constraints. Returns 'true' if this test
641 // fails on any equality. Returns 'false' otherwise.
642 // This test can be used to disprove the existence of a solution. If it returns
643 // true, no integer solution to the equality constraints can exist.
644 //
645 // GCD test definition:
646 //
647 // The equality constraint:
648 //
649 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
650 //
651 // has an integer solution iff:
652 //
653 //  GCD of c_1, c_2, ..., c_n divides c_0.
654 //
isEmptyByGCDTest() const655 bool IntegerRelation::isEmptyByGCDTest() const {
656   assert(hasConsistentState());
657   unsigned numCols = getNumCols();
658   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
659     uint64_t gcd = std::abs(atEq(i, 0));
660     for (unsigned j = 1; j < numCols - 1; ++j) {
661       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
662     }
663     int64_t v = std::abs(atEq(i, numCols - 1));
664     if (gcd > 0 && (v % gcd != 0)) {
665       return true;
666     }
667   }
668   return false;
669 }
670 
671 // Returns a matrix where each row is a vector along which the polytope is
672 // bounded. The span of the returned vectors is guaranteed to contain all
673 // such vectors. The returned vectors are NOT guaranteed to be linearly
674 // independent. This function should not be called on empty sets.
675 //
676 // It is sufficient to check the perpendiculars of the constraints, as the set
677 // of perpendiculars which are bounded must span all bounded directions.
getBoundedDirections() const678 Matrix IntegerRelation::getBoundedDirections() const {
679   // Note that it is necessary to add the equalities too (which the constructor
680   // does) even though we don't need to check if they are bounded; whether an
681   // inequality is bounded or not depends on what other constraints, including
682   // equalities, are present.
683   Simplex simplex(*this);
684 
685   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
686                                "direction is bounded in an empty set.");
687 
688   SmallVector<unsigned, 8> boundedIneqs;
689   // The constructor adds the inequalities to the simplex first, so this
690   // processes all the inequalities.
691   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
692     if (simplex.isBoundedAlongConstraint(i))
693       boundedIneqs.push_back(i);
694   }
695 
696   // The direction vector is given by the coefficients and does not include the
697   // constant term, so the matrix has one fewer column.
698   unsigned dirsNumCols = getNumCols() - 1;
699   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
700 
701   // Copy the bounded inequalities.
702   unsigned row = 0;
703   for (unsigned i : boundedIneqs) {
704     for (unsigned col = 0; col < dirsNumCols; ++col)
705       dirs(row, col) = atIneq(i, col);
706     ++row;
707   }
708 
709   // Copy the equalities. All the equalities' perpendiculars are bounded.
710   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
711     for (unsigned col = 0; col < dirsNumCols; ++col)
712       dirs(row, col) = atEq(i, col);
713     ++row;
714   }
715 
716   return dirs;
717 }
718 
isIntegerEmpty() const719 bool IntegerRelation::isIntegerEmpty() const { return !findIntegerSample(); }
720 
/// Returns an integer point contained in this set, or an empty Optional if
/// the set is found to contain no integer point.
///
/// Let this set be S. If S is bounded then we directly call into the GBR
/// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
/// vectors v such that S extends to infinity along v or -v. In this case we
/// use an algorithm described in the integer set library (isl) manual and used
/// by the isl_set_sample function in that library. The algorithm is:
///
/// 1) Apply a unimodular transform T to S to obtain S*T, such that all
/// dimensions in which S*T is bounded lie in the linear span of a prefix of the
/// dimensions.
///
/// 2) Construct a set B by removing all constraints that involve
/// the unbounded dimensions and then deleting the unbounded dimensions. Note
/// that B is a Bounded set.
///
/// 3) Try to obtain a sample from B using the GBR sampling
/// algorithm. If no sample is found, return that S is empty.
///
/// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
/// C. C is a full-dimensional Cone and always contains a sample.
///
/// 5) Obtain an integer sample from C.
///
/// 6) Return T*v, where v is the concatenation of the samples from B and C.
///
/// The following is a sketch of a proof that
/// a) If the algorithm returns empty, then S is empty.
/// b) If the algorithm returns a sample, it is a valid sample in S.
///
/// The algorithm returns empty only if B is empty, in which case S*T is
/// certainly empty since B was obtained by removing constraints and then
/// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
/// v is in S*T iff T*v is in S. So in this case, since
/// S*T is empty, S is empty too.
///
/// Otherwise, the algorithm substitutes the sample from B into S*T. All the
/// constraints of S*T that did not involve unbounded dimensions are satisfied
/// by this substitution. All dimensions in the linear span of the dimensions
/// outside the prefix are unbounded in S*T (step 1). Substituting values for
/// the bounded dimensions cannot make these dimensions bounded, and these are
/// the only remaining dimensions in C, so C is unbounded along every vector (in
/// the positive or negative direction, or both). C is hence a full-dimensional
/// cone and therefore always contains an integer point.
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  // A rationally empty set has no integer points either.
  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector in which S is
  // bounded, such that the linear span of all these dimensions contains all
  // bounded dimensions in S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings m
  // to column echelon form is unimodular as well, so this is a suitable
  // transform to use in step 1 of the algorithm.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  IntegerRelation transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set.
  IntegerRelation boundedSet(transformedSet);
  // The first element of the result is used as the number of bounded
  // dimensions, i.e. the length of the prefix described in step 1.
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumVars() - numBoundedDims;
  removeConstraintsInvolvingVarRange(boundedSet, numBoundedDims,
                                     numUnboundedDims);
  boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  transformedSet.setAndEliminate(0, *boundedSample);
  // From here on transformedSet is only accessed through the name `cone`.
  IntegerRelation &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.getNumVars(); ++j) {
      int64_t coeff = cone.atIneq(i, j);
      // Column getNumVars() is the constant term of the inequality.
      if (coeff < 0)
        cone.atIneq(i, cone.getNumVars()) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");

  // The sample will always exist since the shrunken cone is non-empty.
  SmallVector<Fraction, 8> shrunkenConeSample =
      *shrunkenConeSimplex.getRationalSample();

  // Round every coordinate of the rational sample up to an integer.
  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  // The concatenation is built in place inside boundedSample.
  SmallVector<int64_t, 8> &sample = *boundedSample;
  sample.append(coneSample.begin(), coneSample.end());
  return transform.postMultiplyWithColumn(sample);
}
868 
869 /// Helper to evaluate an affine expression at a point.
870 /// The expression is a list of coefficients for the dimensions followed by the
871 /// constant term.
valueAt(ArrayRef<int64_t> expr,ArrayRef<int64_t> point)872 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
873   assert(expr.size() == 1 + point.size() &&
874          "Dimensionalities of point and expression don't match!");
875   int64_t value = expr.back();
876   for (unsigned i = 0; i < point.size(); ++i)
877     value += expr[i] * point[i];
878   return value;
879 }
880 
881 /// A point satisfies an equality iff the value of the equality at the
882 /// expression is zero, and it satisfies an inequality iff the value of the
883 /// inequality at that point is non-negative.
containsPoint(ArrayRef<int64_t> point) const884 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
885   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
886     if (valueAt(getEquality(i), point) != 0)
887       return false;
888   }
889   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
890     if (valueAt(getInequality(i), point) < 0)
891       return false;
892   }
893   return true;
894 }
895 
896 /// Just substitute the values given and check if an integer sample exists for
897 /// the local vars.
898 ///
899 /// TODO: this could be made more efficient by handling divisions separately.
900 /// Instead of finding an integer sample over all the locals, we can first
901 /// compute the values of the locals that have division representations and
902 /// only use the integer emptiness check for the locals that don't have this.
903 /// Handling this correctly requires ordering the divs, though.
904 Optional<SmallVector<int64_t, 8>>
containsPointNoLocal(ArrayRef<int64_t> point) const905 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
906   assert(point.size() == getNumVars() - getNumLocalVars() &&
907          "Point should contain all vars except locals!");
908   assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() &&
909          "This function depends on locals being stored last!");
910   IntegerRelation copy = *this;
911   copy.setAndEliminate(0, point);
912   return copy.findIntegerSample();
913 }
914 
915 DivisionRepr
getLocalReprs(std::vector<MaybeLocalRepr> * repr) const916 IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> *repr) const {
917   SmallVector<bool, 8> foundRepr(getNumVars(), false);
918   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
919     foundRepr[i] = true;
920 
921   unsigned localOffset = getVarKindOffset(VarKind::Local);
922   DivisionRepr divs(getNumVars(), getNumLocalVars());
923   bool changed;
924   do {
925     // Each time changed is true, at end of this iteration, one or more local
926     // vars have been detected as floor divs.
927     changed = false;
928     for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) {
929       if (!foundRepr[i + localOffset]) {
930         MaybeLocalRepr res =
931             computeSingleVarRepr(*this, foundRepr, localOffset + i,
932                                  divs.getDividend(i), divs.getDenom(i));
933         if (!res) {
934           // No representation was found, so clear the representation and
935           // continue.
936           divs.clearRepr(i);
937           continue;
938         }
939         foundRepr[localOffset + i] = true;
940         if (repr)
941           (*repr)[i] = res;
942         changed = true;
943       }
944     }
945   } while (changed);
946 
947   return divs;
948 }
949 
950 /// Tightens inequalities given that we are dealing with integer spaces. This is
951 /// analogous to the GCD test but applied to inequalities. The constant term can
952 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
953 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
954 /// fast method - linear in the number of coefficients.
955 // Example on how this affects practical cases: consider the scenario:
956 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
957 // j >= 100 instead of the tighter (exact) j >= 128.
gcdTightenInequalities()958 void IntegerRelation::gcdTightenInequalities() {
959   unsigned numCols = getNumCols();
960   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
961     // Normalize the constraint and tighten the constant term by the GCD.
962     int64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
963     if (gcd > 1)
964       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
965   }
966 }
967 
// Eliminates variables in the column range [posStart, posLimit) using the
// equalities of the system. Columns are processed in increasing order.
// Processing stops at the first column that has no non-zero coefficient in any
// equality but does have one in some inequality; only the columns before that
// one are removed.
// Returns the number of variables eliminated.
unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
                                                unsigned posLimit) {
  // The variable positions to eliminate must be in range.
  assert(posLimit <= getNumVars());
  assert(hasConsistentState());

  // Empty range: nothing to eliminate.
  if (posStart >= posLimit)
    return 0;

  gcdTightenInequalities();

  unsigned pivotCol = 0;
  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
    // Find a row which has a non-zero coefficient in column 'pivotCol'.
    unsigned pivotRow;
    if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
      // No pivot row in equalities with non-zero at 'pivotCol'.
      if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
        // If the inequalities are also all zero in 'pivotCol', the variable is
        // unconstrained and its column can simply be dropped.
        continue;
      }
      // The variable appears only in inequalities, so it cannot be eliminated
      // with an equality. Stop here.
      break;
    }

    // Eliminate variable at 'pivotCol' from each equality row.
    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/true);
      equalities.normalizeRow(i);
    }

    // Eliminate variable at 'pivotCol' from each inequality row.
    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/false);
      inequalities.normalizeRow(i);
    }
    // The pivot equality has served its purpose; drop it and re-tighten the
    // inequalities that were just rewritten.
    removeEquality(pivotRow);
    gcdTightenInequalities();
  }
  // Update position limit based on number eliminated.
  posLimit = pivotCol;
  // Remove eliminated columns from all constraints.
  removeVarRange(posStart, posLimit);
  return posLimit - posStart;
}
1017 
1018 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1019 // to check if a constraint is redundant.
removeRedundantInequalities()1020 void IntegerRelation::removeRedundantInequalities() {
1021   SmallVector<bool, 32> redun(getNumInequalities(), false);
1022   // To check if an inequality is redundant, we replace the inequality by its
1023   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1024   // system is empty. If it is, the inequality is redundant.
1025   IntegerRelation tmpCst(*this);
1026   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1027     // Change the inequality to its complement.
1028     tmpCst.inequalities.negateRow(r);
1029     --tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
1030     if (tmpCst.isEmpty()) {
1031       redun[r] = true;
1032       // Zero fill the redundant inequality.
1033       inequalities.fillRow(r, /*value=*/0);
1034       tmpCst.inequalities.fillRow(r, /*value=*/0);
1035     } else {
1036       // Reverse the change (to avoid recreating tmpCst each time).
1037       ++tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
1038       tmpCst.inequalities.negateRow(r);
1039     }
1040   }
1041 
1042   unsigned pos = 0;
1043   for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
1044     if (!redun[r])
1045       inequalities.copyRow(r, pos++);
1046   }
1047   inequalities.resizeVertically(pos);
1048 }
1049 
1050 // A more complex check to eliminate redundant inequalities and equalities. Uses
1051 // Simplex to check if a constraint is redundant.
removeRedundantConstraints()1052 void IntegerRelation::removeRedundantConstraints() {
1053   // First, we run gcdTightenInequalities. This allows us to catch some
1054   // constraints which are not redundant when considering rational solutions
1055   // but are redundant in terms of integer solutions.
1056   gcdTightenInequalities();
1057   Simplex simplex(*this);
1058   simplex.detectRedundant();
1059 
1060   unsigned pos = 0;
1061   unsigned numIneqs = getNumInequalities();
1062   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1063   // the first constraints added are the inequalities.
1064   for (unsigned r = 0; r < numIneqs; r++) {
1065     if (!simplex.isMarkedRedundant(r))
1066       inequalities.copyRow(r, pos++);
1067   }
1068   inequalities.resizeVertically(pos);
1069 
1070   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1071   // after the inequalities, a pair of constraints for each equality is added.
1072   // An equality is redundant if both the inequalities in its pair are
1073   // redundant.
1074   pos = 0;
1075   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1076     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1077           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1078       equalities.copyRow(r, pos++);
1079   }
1080   equalities.resizeVertically(pos);
1081 }
1082 
computeVolume() const1083 Optional<uint64_t> IntegerRelation::computeVolume() const {
1084   assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
1085 
1086   Simplex simplex(*this);
1087   // If the polytope is rationally empty, there are certainly no integer
1088   // points.
1089   if (simplex.isEmpty())
1090     return 0;
1091 
1092   // Just find the maximum and minimum integer value of each non-local var
1093   // separately, thus finding the number of integer values each such var can
1094   // take. Multiplying these together gives a valid overapproximation of the
1095   // number of integer points in the relation. The result this gives is
1096   // equivalent to projecting (rationally) the relation onto its non-local vars
1097   // and returning the number of integer points in a minimal axis-parallel
1098   // hyperrectangular overapproximation of that.
1099   //
1100   // We also handle the special case where one dimension is unbounded and
1101   // another dimension can take no integer values. In this case, the volume is
1102   // zero.
1103   //
1104   // If there is no such empty dimension, if any dimension is unbounded we
1105   // just return the result as unbounded.
1106   uint64_t count = 1;
1107   SmallVector<int64_t, 8> dim(getNumVars() + 1);
1108   bool hasUnboundedVar = false;
1109   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) {
1110     dim[i] = 1;
1111     MaybeOptimum<int64_t> min, max;
1112     std::tie(min, max) = simplex.computeIntegerBounds(dim);
1113     dim[i] = 0;
1114 
1115     assert((!min.isEmpty() && !max.isEmpty()) &&
1116            "Polytope should be rationally non-empty!");
1117 
1118     // One of the dimensions is unbounded. Note this fact. We will return
1119     // unbounded if none of the other dimensions makes the volume zero.
1120     if (min.isUnbounded() || max.isUnbounded()) {
1121       hasUnboundedVar = true;
1122       continue;
1123     }
1124 
1125     // In this case there are no valid integer points and the volume is
1126     // definitely zero.
1127     if (min.getBoundedOptimum() > max.getBoundedOptimum())
1128       return 0;
1129 
1130     count *= (*max - *min + 1);
1131   }
1132 
1133   if (count == 0)
1134     return 0;
1135   if (hasUnboundedVar)
1136     return {};
1137   return count;
1138 }
1139 
eliminateRedundantLocalVar(unsigned posA,unsigned posB)1140 void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
1141   assert(posA < getNumLocalVars() && "Invalid local var position");
1142   assert(posB < getNumLocalVars() && "Invalid local var position");
1143 
1144   unsigned localOffset = getVarKindOffset(VarKind::Local);
1145   posA += localOffset;
1146   posB += localOffset;
1147   inequalities.addToColumn(posB, posA, 1);
1148   equalities.addToColumn(posB, posA, 1);
1149   removeVar(posB);
1150 }
1151 
1152 /// Adds additional local ids to the sets such that they both have the union
1153 /// of the local ids in each set, without changing the set of points that
1154 /// lie in `this` and `other`.
1155 ///
1156 /// To detect local ids that always take the same value, each local id is
1157 /// represented as a floordiv with constant denominator in terms of other ids.
1158 /// After extracting these divisions, local ids in `other` with the same
1159 /// division representation as some other local id in any set are considered
1160 /// duplicate and are merged.
1161 ///
1162 /// It is possible that division representation for some local id cannot be
1163 /// obtained, and thus these local ids are not considered for detecting
1164 /// duplicates.
mergeLocalVars(IntegerRelation & other)1165 unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) {
1166   IntegerRelation &relA = *this;
1167   IntegerRelation &relB = other;
1168 
1169   unsigned oldALocals = relA.getNumLocalVars();
1170 
1171   // Merge function that merges the local variables in both sets by treating
1172   // them as the same variable.
1173   auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
1174     // We only merge from local at pos j to local at pos i, where j > i.
1175     if (i >= j)
1176       return false;
1177 
1178     // If i < oldALocals, we are trying to merge duplicate divs. Since we do not
1179     // want to merge duplicates in A, we ignore this call.
1180     if (j < oldALocals)
1181       return false;
1182 
1183     // Merge local at pos j into local at position i.
1184     relA.eliminateRedundantLocalVar(i, j);
1185     relB.eliminateRedundantLocalVar(i, j);
1186     return true;
1187   };
1188 
1189   presburger::mergeLocalVars(*this, other, merge);
1190 
1191   // Since we do not remove duplicate divisions in relA, this is guranteed to be
1192   // non-negative.
1193   return relA.getNumLocalVars() - oldALocals;
1194 }
1195 
hasOnlyDivLocals() const1196 bool IntegerRelation::hasOnlyDivLocals() const {
1197   return getLocalReprs().hasAllReprs();
1198 }
1199 
removeDuplicateDivs()1200 void IntegerRelation::removeDuplicateDivs() {
1201   DivisionRepr divs = getLocalReprs();
1202   auto merge = [this](unsigned i, unsigned j) -> bool {
1203     eliminateRedundantLocalVar(i, j);
1204     return true;
1205   };
1206   divs.removeDuplicateDivs(merge);
1207 }
1208 
/// Removes local variables using equalities. Each equality is checked if it
/// can be reduced to the form: `e = affine-expr`, where `e` is a local
/// variable and `affine-expr` is an affine expression not containing `e`.
/// If an equality satisfies this form, the local variable is replaced in
/// each constraint and then removed. The equality used to replace this local
/// variable is also removed.
///
/// Concretely, an equality qualifies if, after normalization, it has a
/// coefficient of +1 or -1 on some local variable. The process repeats until
/// no equality qualifies.
void IntegerRelation::removeRedundantLocalVars() {
  // Normalize the equality constraints to reduce coefficients of local
  // variables to 1 wherever possible.
  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
    equalities.normalizeRow(i);

  while (true) {
    // i: equality row found, j: local variable column found. Both are
    // declared outside the loops because they are used after the search.
    unsigned i, e, j, f;
    for (i = 0, e = getNumEqualities(); i < e; ++i) {
      // Find a local variable to eliminate using ith equality.
      // The search covers the columns after the dim and symbol variables.
      for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j)
        if (std::abs(atEq(i, j)) == 1)
          break;

      // Local variable can be eliminated using ith equality.
      if (j < f)
        break;
    }

    // No equality can be used to eliminate a local variable.
    if (i == e)
      break;

    // Use the ith equality to simplify other equalities. If any changes
    // are made to an equality constraint, it is normalized by GCD.
    for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
      if (atEq(k, j) != 0) {
        eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
        equalities.normalizeRow(k);
      }
    }

    // Use the ith equality to simplify inequalities.
    for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
      eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);

    // Remove the ith equality and the found local variable.
    removeVar(j);
    removeEquality(i);
  }
}
1256 
convertVarKind(VarKind srcKind,unsigned varStart,unsigned varLimit,VarKind dstKind,unsigned pos)1257 void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
1258                                      unsigned varLimit, VarKind dstKind,
1259                                      unsigned pos) {
1260   assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range");
1261 
1262   if (varStart >= varLimit)
1263     return;
1264 
1265   // Append new local variables corresponding to the dimensions to be converted.
1266   unsigned convertCount = varLimit - varStart;
1267   unsigned newVarsBegin = insertVar(dstKind, pos, convertCount);
1268 
1269   // Swap the new local variables with dimensions.
1270   //
1271   // Essentially, this moves the information corresponding to the specified ids
1272   // of kind `srcKind` to the `convertCount` newly created ids of kind
1273   // `dstKind`. In particular, this moves the columns in the constraint
1274   // matrices, and zeros out the initially occupied columns (because the newly
1275   // created ids we're swapping with were zero-initialized).
1276   unsigned offset = getVarKindOffset(srcKind);
1277   for (unsigned i = 0; i < convertCount; ++i)
1278     swapVar(offset + varStart + i, newVarsBegin + i);
1279 
1280   // Complete the move by deleting the initially occupied columns.
1281   removeVarRange(srcKind, varStart, varLimit);
1282 }
1283 
addBound(BoundType type,unsigned pos,int64_t value)1284 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
1285   assert(pos < getNumCols());
1286   if (type == BoundType::EQ) {
1287     unsigned row = equalities.appendExtraRow();
1288     equalities(row, pos) = 1;
1289     equalities(row, getNumCols() - 1) = -value;
1290   } else {
1291     unsigned row = inequalities.appendExtraRow();
1292     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
1293     inequalities(row, getNumCols() - 1) =
1294         type == BoundType::LB ? -value : value;
1295   }
1296 }
1297 
addBound(BoundType type,ArrayRef<int64_t> expr,int64_t value)1298 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
1299                                int64_t value) {
1300   assert(type != BoundType::EQ && "EQ not implemented");
1301   assert(expr.size() == getNumCols());
1302   unsigned row = inequalities.appendExtraRow();
1303   for (unsigned i = 0, e = expr.size(); i < e; ++i)
1304     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
1305   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
1306       type == BoundType::LB ? -value : value;
1307 }
1308 
1309 /// Adds a new local variable as the floordiv of an affine function of other
1310 /// variables, the coefficients of which are provided in 'dividend' and with
1311 /// respect to a positive constant 'divisor'. Two constraints are added to the
1312 /// system to capture equivalence with the floordiv.
1313 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)1314 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1315                                        int64_t divisor) {
1316   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1317   assert(divisor > 0 && "positive divisor expected");
1318 
1319   appendVar(VarKind::Local);
1320 
1321   SmallVector<int64_t, 8> dividendCopy(dividend.begin(), dividend.end());
1322   dividendCopy.insert(dividendCopy.end() - 1, 0);
1323   addInequality(
1324       getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
1325   addInequality(
1326       getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
1327 }
1328 
1329 /// Finds an equality that equates the specified variable to a constant.
1330 /// Returns the position of the equality row. If 'symbolic' is set to true,
1331 /// symbols are also treated like a constant, i.e., an affine function of the
1332 /// symbols is also treated like a constant. Returns -1 if such an equality
1333 /// could not be found.
findEqualityToConstant(const IntegerRelation & cst,unsigned pos,bool symbolic=false)1334 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
1335                                   bool symbolic = false) {
1336   assert(pos < cst.getNumVars() && "invalid position");
1337   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1338     int64_t v = cst.atEq(r, pos);
1339     if (v * v != 1)
1340       continue;
1341     unsigned c;
1342     unsigned f = symbolic ? cst.getNumDimVars() : cst.getNumVars();
1343     // This checks for zeros in all positions other than 'pos' in [0, f)
1344     for (c = 0; c < f; c++) {
1345       if (c == pos)
1346         continue;
1347       if (cst.atEq(r, c) != 0) {
1348         // Dependent on another variable.
1349         break;
1350       }
1351     }
1352     if (c == f)
1353       // Equality is free of other variables.
1354       return r;
1355   }
1356   return -1;
1357 }
1358 
constantFoldVar(unsigned pos)1359 LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
1360   assert(pos < getNumVars() && "invalid position");
1361   int rowIdx;
1362   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
1363     return failure();
1364 
1365   // atEq(rowIdx, pos) is either -1 or 1.
1366   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
1367   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
1368   setAndEliminate(pos, constVal);
1369   return success();
1370 }
1371 
constantFoldVarRange(unsigned pos,unsigned num)1372 void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
1373   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
1374     if (failed(constantFoldVar(t)))
1375       t++;
1376   }
1377 }
1378 
1379 /// Returns a non-negative constant bound on the extent (upper bound - lower
1380 /// bound) of the specified variable if it is found to be a constant; returns
1381 /// None if it's not a constant. This methods treats symbolic variables
1382 /// specially, i.e., it looks for constant differences between affine
1383 /// expressions involving only the symbolic variables. See comments at
1384 /// function definition for example. 'lb', if provided, is set to the lower
1385 /// bound associated with the constant difference. Note that 'lb' is purely
1386 /// symbolic and thus will contain the coefficients of the symbolic variables
1387 /// and the constant coefficient.
1388 //  Egs: 0 <= i <= 15, return 16.
1389 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
1390 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
1391 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
1392 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
getConstantBoundOnDimSize(unsigned pos,SmallVectorImpl<int64_t> * lb,int64_t * boundFloorDivisor,SmallVectorImpl<int64_t> * ub,unsigned * minLbPos,unsigned * minUbPos) const1393 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
1394     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
1395     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
1396     unsigned *minUbPos) const {
1397   assert(pos < getNumDimVars() && "Invalid variable position");
1398 
1399   // Find an equality for 'pos'^th variable that equates it to some function
1400   // of the symbolic variables (+ constant).
1401   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
1402   if (eqPos != -1) {
1403     auto eq = getEquality(eqPos);
1404     // If the equality involves a local var, punt for now.
1405     // TODO: this can be handled in the future by using the explicit
1406     // representation of the local vars.
1407     if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
1408                      [](int64_t coeff) { return coeff == 0; }))
1409       return None;
1410 
1411     // This variable can only take a single value.
1412     if (lb) {
1413       // Set lb to that symbolic value.
1414       lb->resize(getNumSymbolVars() + 1);
1415       if (ub)
1416         ub->resize(getNumSymbolVars() + 1);
1417       for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
1418         int64_t v = atEq(eqPos, pos);
1419         // atEq(eqRow, pos) is either -1 or 1.
1420         assert(v * v == 1);
1421         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
1422                          : -atEq(eqPos, getNumDimVars() + c) / v;
1423         // Since this is an equality, ub = lb.
1424         if (ub)
1425           (*ub)[c] = (*lb)[c];
1426       }
1427       assert(boundFloorDivisor &&
1428              "both lb and divisor or none should be provided");
1429       *boundFloorDivisor = 1;
1430     }
1431     if (minLbPos)
1432       *minLbPos = eqPos;
1433     if (minUbPos)
1434       *minUbPos = eqPos;
1435     return 1;
1436   }
1437 
1438   // Check if the variable appears at all in any of the inequalities.
1439   unsigned r, e;
1440   for (r = 0, e = getNumInequalities(); r < e; r++) {
1441     if (atIneq(r, pos) != 0)
1442       break;
1443   }
1444   if (r == e)
1445     // If it doesn't, there isn't a bound on it.
1446     return None;
1447 
1448   // Positions of constraints that are lower/upper bounds on the variable.
1449   SmallVector<unsigned, 4> lbIndices, ubIndices;
1450 
1451   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
1452   // the bounds can only involve symbolic (and local) variables. Since the
1453   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1454   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1455   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
1456                                /*eqIndices=*/nullptr, /*offset=*/0,
1457                                /*num=*/getNumDimVars());
1458 
1459   Optional<int64_t> minDiff = None;
1460   unsigned minLbPosition = 0, minUbPosition = 0;
1461   for (auto ubPos : ubIndices) {
1462     for (auto lbPos : lbIndices) {
1463       // Look for a lower bound and an upper bound that only differ by a
1464       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
1465       // For example, if ii is the pos^th variable, we are looking for
1466       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
1467       // minimum among all such constant differences is kept since that's the
1468       // constant bounding the extent of the pos^th variable.
1469       unsigned j, e;
1470       for (j = 0, e = getNumCols() - 1; j < e; j++)
1471         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
1472           break;
1473         }
1474       if (j < getNumCols() - 1)
1475         continue;
1476       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
1477                                  atIneq(lbPos, getNumCols() - 1) + 1,
1478                              atIneq(lbPos, pos));
1479       // This bound is non-negative by definition.
1480       diff = std::max<int64_t>(diff, 0);
1481       if (minDiff == None || diff < minDiff) {
1482         minDiff = diff;
1483         minLbPosition = lbPos;
1484         minUbPosition = ubPos;
1485       }
1486     }
1487   }
1488   if (lb && minDiff) {
1489     // Set lb to the symbolic lower bound.
1490     lb->resize(getNumSymbolVars() + 1);
1491     if (ub)
1492       ub->resize(getNumSymbolVars() + 1);
1493     // The lower bound is the ceildiv of the lb constraint over the coefficient
1494     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
1495     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
1496     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
1497     *boundFloorDivisor = atIneq(minLbPosition, pos);
1498     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1499     for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
1500       (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
1501     }
1502     if (ub) {
1503       for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
1504         (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
1505     }
1506     // The lower bound leads to a ceildiv while the upper bound is a floordiv
1507     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
1508     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
1509     // the constant term for the lower bound.
1510     (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
1511   }
1512   if (minLbPos)
1513     *minLbPos = minLbPosition;
1514   if (minUbPos)
1515     *minUbPos = minUbPosition;
1516   return minDiff;
1517 }
1518 
1519 template <bool isLower>
1520 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)1521 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
1522   assert(pos < getNumVars() && "invalid position");
1523   // Project to 'pos'.
1524   projectOut(0, pos);
1525   projectOut(1, getNumVars() - 1);
1526   // Check if there's an equality equating the '0'^th variable to a constant.
1527   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
1528   if (eqRowIdx != -1)
1529     // atEq(rowIdx, 0) is either -1 or 1.
1530     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
1531 
1532   // Check if the variable appears at all in any of the inequalities.
1533   unsigned r, e;
1534   for (r = 0, e = getNumInequalities(); r < e; r++) {
1535     if (atIneq(r, 0) != 0)
1536       break;
1537   }
1538   if (r == e)
1539     // If it doesn't, there isn't a bound on it.
1540     return None;
1541 
1542   Optional<int64_t> minOrMaxConst = None;
1543 
1544   // Take the max across all const lower bounds (or min across all constant
1545   // upper bounds).
1546   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1547     if (isLower) {
1548       if (atIneq(r, 0) <= 0)
1549         // Not a lower bound.
1550         continue;
1551     } else if (atIneq(r, 0) >= 0) {
1552       // Not an upper bound.
1553       continue;
1554     }
1555     unsigned c, f;
1556     for (c = 0, f = getNumCols() - 1; c < f; c++)
1557       if (c != 0 && atIneq(r, c) != 0)
1558         break;
1559     if (c < getNumCols() - 1)
1560       // Not a constant bound.
1561       continue;
1562 
1563     int64_t boundConst =
1564         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
1565                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
1566     if (isLower) {
1567       if (minOrMaxConst == None || boundConst > minOrMaxConst)
1568         minOrMaxConst = boundConst;
1569     } else {
1570       if (minOrMaxConst == None || boundConst < minOrMaxConst)
1571         minOrMaxConst = boundConst;
1572     }
1573   }
1574   return minOrMaxConst;
1575 }
1576 
getConstantBound(BoundType type,unsigned pos) const1577 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
1578                                                     unsigned pos) const {
1579   if (type == BoundType::LB)
1580     return IntegerRelation(*this)
1581         .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
1582   if (type == BoundType::UB)
1583     return IntegerRelation(*this)
1584         .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1585 
1586   assert(type == BoundType::EQ && "expected EQ");
1587   Optional<int64_t> lb =
1588       IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
1589           pos);
1590   Optional<int64_t> ub =
1591       IntegerRelation(*this)
1592           .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1593   return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
1594 }
1595 
1596 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const1597 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
1598   assert(pos < getNumCols() - 1);
1599   // Check for two non-zero coefficients in the range [pos, pos + sum).
1600   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1601     unsigned sum = 0;
1602     for (unsigned c = pos; c < pos + num; c++) {
1603       if (atIneq(r, c) != 0)
1604         sum++;
1605     }
1606     if (sum > 1)
1607       return false;
1608   }
1609   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1610     unsigned sum = 0;
1611     for (unsigned c = pos; c < pos + num; c++) {
1612       if (atEq(r, c) != 0)
1613         sum++;
1614     }
1615     if (sum > 1)
1616       return false;
1617   }
1618   return true;
1619 }
1620 
1621 /// Removes duplicate constraints, trivially true constraints, and constraints
1622 /// that can be detected as redundant as a result of differing only in their
1623 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
1624 /// considered trivially true.
1625 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
1626 //  remove duplicates in place.
removeTrivialRedundancy()1627 void IntegerRelation::removeTrivialRedundancy() {
1628   gcdTightenInequalities();
1629   normalizeConstraintsByGCD();
1630 
1631   // A map used to detect redundancy stemming from constraints that only differ
1632   // in their constant term. The value stored is <row position, const term>
1633   // for a given row.
1634   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
1635       rowsWithoutConstTerm;
1636   // To unique rows.
1637   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
1638 
1639   // Check if constraint is of the form <non-negative-constant> >= 0.
1640   auto isTriviallyValid = [&](unsigned r) -> bool {
1641     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
1642       if (atIneq(r, c) != 0)
1643         return false;
1644     }
1645     return atIneq(r, getNumCols() - 1) >= 0;
1646   };
1647 
1648   // Detect and mark redundant constraints.
1649   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
1650   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1651     int64_t *rowStart = &inequalities(r, 0);
1652     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
1653     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
1654       redunIneq[r] = true;
1655       continue;
1656     }
1657 
1658     // Among constraints that only differ in the constant term part, mark
1659     // everything other than the one with the smallest constant term redundant.
1660     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
1661     // former two are redundant).
1662     int64_t constTerm = atIneq(r, getNumCols() - 1);
1663     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
1664     const auto &ret =
1665         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
1666     if (!ret.second) {
1667       // Check if the other constraint has a higher constant term.
1668       auto &val = ret.first->second;
1669       if (val.second > constTerm) {
1670         // The stored row is redundant. Mark it so, and update with this one.
1671         redunIneq[val.first] = true;
1672         val = {r, constTerm};
1673       } else {
1674         // The one stored makes this one redundant.
1675         redunIneq[r] = true;
1676       }
1677     }
1678   }
1679 
1680   // Scan to get rid of all rows marked redundant, in-place.
1681   unsigned pos = 0;
1682   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1683     if (!redunIneq[r])
1684       inequalities.copyRow(r, pos++);
1685 
1686   inequalities.resizeVertically(pos);
1687 
1688   // TODO: consider doing this for equalities as well, but probably not worth
1689   // the savings.
1690 }
1691 
1692 #undef DEBUG_TYPE
1693 #define DEBUG_TYPE "fm"
1694 
1695 /// Eliminates variable at the specified position using Fourier-Motzkin
1696 /// variable elimination. This technique is exact for rational spaces but
1697 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
1698 /// to a projection operation yielding the (convex) set of integer points
1699 /// contained in the rational shadow of the set. An emptiness test that relies
1700 /// on this method will guarantee emptiness, i.e., it disproves the existence of
1701 /// a solution if it says it's empty.
1702 /// If a non-null isResultIntegerExact is passed, it is set to true if the
1703 /// result is also integer exact. If it's set to false, the obtained solution
1704 /// *may* not be exact, i.e., it may contain integer points that do not have an
1705 /// integer pre-image in the original set.
1706 ///
1707 /// Eg:
1708 /// j >= 0, j <= i + 1
1709 /// i >= 0, i <= N + 1
1710 /// Eliminating i yields,
1711 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
1712 ///
1713 /// If darkShadow = true, this method computes the dark shadow on elimination;
1714 /// the dark shadow is a convex integer subset of the exact integer shadow. A
1715 /// non-empty dark shadow proves the existence of an integer solution. The
1716 /// elimination in such a case could however be an under-approximation, and thus
1717 /// should not be used for scanning sets or used by itself for dependence
1718 /// checking.
1719 ///
1720 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
1721 ///            ^
1722 ///            |
1723 ///            | * * * * o o
1724 ///         i  | * * o o o o
1725 ///            | o * * * * *
1726 ///            --------------->
1727 ///                 j ->
1728 ///
1729 /// Eliminating i from this system (projecting on the j dimension):
1730 /// rational shadow / integer light shadow:  1 <= j <= 6
1731 /// dark shadow:                             3 <= j <= 6
1732 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
1733 /// holes/splinters:                         j = 2
1734 ///
1735 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
1736 // TODO: a slight modification to yield dark shadow version of FM (tightened),
1737 // which can prove the existence of a solution if there is one.
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)1738 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
1739                                               bool *isResultIntegerExact) {
1740   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
1741   LLVM_DEBUG(dump());
1742   assert(pos < getNumVars() && "invalid position");
1743   assert(hasConsistentState());
1744 
1745   // Check if this variable can be eliminated through a substitution.
1746   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1747     if (atEq(r, pos) != 0) {
1748       // Use Gaussian elimination here (since we have an equality).
1749       LogicalResult ret = gaussianEliminateVar(pos);
1750       (void)ret;
1751       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
1752       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
1753       LLVM_DEBUG(dump());
1754       return;
1755     }
1756   }
1757 
1758   // A fast linear time tightening.
1759   gcdTightenInequalities();
1760 
1761   // Check if the variable appears at all in any of the inequalities.
1762   if (isColZero(pos)) {
1763     // If it doesn't appear, just remove the column and return.
1764     // TODO: refactor removeColumns to use it from here.
1765     removeVar(pos);
1766     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1767     LLVM_DEBUG(dump());
1768     return;
1769   }
1770 
1771   // Positions of constraints that are lower bounds on the variable.
1772   SmallVector<unsigned, 4> lbIndices;
1773   // Positions of constraints that are lower bounds on the variable.
1774   SmallVector<unsigned, 4> ubIndices;
1775   // Positions of constraints that do not involve the variable.
1776   std::vector<unsigned> nbIndices;
1777   nbIndices.reserve(getNumInequalities());
1778 
1779   // Gather all lower bounds and upper bounds of the variable. Since the
1780   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1781   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1782   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1783     if (atIneq(r, pos) == 0) {
1784       // Var does not appear in bound.
1785       nbIndices.push_back(r);
1786     } else if (atIneq(r, pos) >= 1) {
1787       // Lower bound.
1788       lbIndices.push_back(r);
1789     } else {
1790       // Upper bound.
1791       ubIndices.push_back(r);
1792     }
1793   }
1794 
1795   PresburgerSpace newSpace = getSpace();
1796   VarKind idKindRemove = newSpace.getVarKindAt(pos);
1797   unsigned relativePos = pos - newSpace.getVarKindOffset(idKindRemove);
1798   newSpace.removeVarRange(idKindRemove, relativePos, relativePos + 1);
1799 
1800   /// Create the new system which has one variable less.
1801   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
1802                          getNumEqualities(), getNumCols() - 1, newSpace);
1803 
1804   // This will be used to check if the elimination was integer exact.
1805   unsigned lcmProducts = 1;
1806 
1807   // Let x be the variable we are eliminating.
1808   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
1809   // that c_l, c_u >= 1) we have:
1810   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
1811   // We thus generate a constraint:
1812   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
1813   // Note if c_l = c_u = 1, all integer points captured by the resulting
1814   // constraint correspond to integer points in the original system (i.e., they
1815   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
1816   // integer exact.
1817   for (auto ubPos : ubIndices) {
1818     for (auto lbPos : lbIndices) {
1819       SmallVector<int64_t, 4> ineq;
1820       ineq.reserve(newRel.getNumCols());
1821       int64_t lbCoeff = atIneq(lbPos, pos);
1822       // Note that in the comments above, ubCoeff is the negation of the
1823       // coefficient in the canonical form as the view taken here is that of the
1824       // term being moved to the other size of '>='.
1825       int64_t ubCoeff = -atIneq(ubPos, pos);
1826       // TODO: refactor this loop to avoid all branches inside.
1827       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1828         if (l == pos)
1829           continue;
1830         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
1831         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
1832         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
1833                        atIneq(lbPos, l) * (lcm / lbCoeff));
1834         lcmProducts *= lcm;
1835       }
1836       if (darkShadow) {
1837         // The dark shadow is a convex subset of the exact integer shadow. If
1838         // there is a point here, it proves the existence of a solution.
1839         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
1840       }
1841       // TODO: we need to have a way to add inequalities in-place in
1842       // IntegerRelation instead of creating and copying over.
1843       newRel.addInequality(ineq);
1844     }
1845   }
1846 
1847   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
1848                           << "\n");
1849   if (lcmProducts == 1 && isResultIntegerExact)
1850     *isResultIntegerExact = true;
1851 
1852   // Copy over the constraints not involving this variable.
1853   for (auto nbPos : nbIndices) {
1854     SmallVector<int64_t, 4> ineq;
1855     ineq.reserve(getNumCols() - 1);
1856     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1857       if (l == pos)
1858         continue;
1859       ineq.push_back(atIneq(nbPos, l));
1860     }
1861     newRel.addInequality(ineq);
1862   }
1863 
1864   assert(newRel.getNumConstraints() ==
1865          lbIndices.size() * ubIndices.size() + nbIndices.size());
1866 
1867   // Copy over the equalities.
1868   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1869     SmallVector<int64_t, 4> eq;
1870     eq.reserve(newRel.getNumCols());
1871     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1872       if (l == pos)
1873         continue;
1874       eq.push_back(atEq(r, l));
1875     }
1876     newRel.addEquality(eq);
1877   }
1878 
1879   // GCD tightening and normalization allows detection of more trivially
1880   // redundant constraints.
1881   newRel.gcdTightenInequalities();
1882   newRel.normalizeConstraintsByGCD();
1883   newRel.removeTrivialRedundancy();
1884   clearAndCopyFrom(newRel);
1885   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1886   LLVM_DEBUG(dump());
1887 }
1888 
1889 #undef DEBUG_TYPE
1890 #define DEBUG_TYPE "presburger"
1891 
projectOut(unsigned pos,unsigned num)1892 void IntegerRelation::projectOut(unsigned pos, unsigned num) {
1893   if (num == 0)
1894     return;
1895 
1896   // 'pos' can be at most getNumCols() - 2 if num > 0.
1897   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
1898   assert(pos + num < getNumCols() && "invalid range");
1899 
1900   // Eliminate as many variables as possible using Gaussian elimination.
1901   unsigned currentPos = pos;
1902   unsigned numToEliminate = num;
1903   unsigned numGaussianEliminated = 0;
1904 
1905   while (currentPos < getNumVars()) {
1906     unsigned curNumEliminated =
1907         gaussianEliminateVars(currentPos, currentPos + numToEliminate);
1908     ++currentPos;
1909     numToEliminate -= curNumEliminated + 1;
1910     numGaussianEliminated += curNumEliminated;
1911   }
1912 
1913   // Eliminate the remaining using Fourier-Motzkin.
1914   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
1915     unsigned numToEliminate = num - numGaussianEliminated - i;
1916     fourierMotzkinEliminate(
1917         getBestVarToEliminate(*this, pos, pos + numToEliminate));
1918   }
1919 
1920   // Fast/trivial simplifications.
1921   gcdTightenInequalities();
1922   // Normalize constraints after tightening since the latter impacts this, but
1923   // not the other way round.
1924   normalizeConstraintsByGCD();
1925 }
1926 
1927 namespace {
1928 
1929 enum BoundCmpResult { Greater, Less, Equal, Unknown };
1930 
1931 /// Compares two affine bounds whose coefficients are provided in 'first' and
1932 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)1933 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1934   assert(a.size() == b.size());
1935 
1936   // For the bounds to be comparable, their corresponding variable
1937   // coefficients should be equal; the constant terms are then compared to
1938   // determine less/greater/equal.
1939 
1940   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
1941     return Unknown;
1942 
1943   if (a.back() == b.back())
1944     return Equal;
1945 
1946   return a.back() < b.back() ? Less : Greater;
1947 }
1948 } // namespace
1949 
1950 // Returns constraints that are common to both A & B.
getCommonConstraints(const IntegerRelation & a,const IntegerRelation & b,IntegerRelation & c)1951 static void getCommonConstraints(const IntegerRelation &a,
1952                                  const IntegerRelation &b, IntegerRelation &c) {
1953   c = IntegerRelation(a.getSpace());
1954   // a naive O(n^2) check should be enough here given the input sizes.
1955   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
1956     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
1957       if (a.getInequality(r) == b.getInequality(s)) {
1958         c.addInequality(a.getInequality(r));
1959         break;
1960       }
1961     }
1962   }
1963   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
1964     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
1965       if (a.getEquality(r) == b.getEquality(s)) {
1966         c.addEquality(a.getEquality(r));
1967         break;
1968       }
1969     }
1970   }
1971 }
1972 
1973 // Computes the bounding box with respect to 'other' by finding the min of the
1974 // lower bounds and the max of the upper bounds along each of the dimensions.
1975 LogicalResult
unionBoundingBox(const IntegerRelation & otherCst)1976 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
1977   assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
1978   assert(getNumLocalVars() == 0 && "local ids not supported yet here");
1979 
1980   // Get the constraints common to both systems; these will be added as is to
1981   // the union.
1982   IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
1983   getCommonConstraints(*this, otherCst, commonCst);
1984 
1985   std::vector<SmallVector<int64_t, 8>> boundingLbs;
1986   std::vector<SmallVector<int64_t, 8>> boundingUbs;
1987   boundingLbs.reserve(2 * getNumDimVars());
1988   boundingUbs.reserve(2 * getNumDimVars());
1989 
1990   // To hold lower and upper bounds for each dimension.
1991   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
1992   // To compute min of lower bounds and max of upper bounds for each dimension.
1993   SmallVector<int64_t, 4> minLb(getNumSymbolVars() + 1);
1994   SmallVector<int64_t, 4> maxUb(getNumSymbolVars() + 1);
1995   // To compute final new lower and upper bounds for the union.
1996   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
1997 
1998   int64_t lbFloorDivisor, otherLbFloorDivisor;
1999   for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
2000     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2001     if (!extent.has_value())
2002       // TODO: symbolic extents when necessary.
2003       // TODO: handle union if a dimension is unbounded.
2004       return failure();
2005 
2006     auto otherExtent = otherCst.getConstantBoundOnDimSize(
2007         d, &otherLb, &otherLbFloorDivisor, &otherUb);
2008     if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
2009       // TODO: symbolic extents when necessary.
2010       return failure();
2011 
2012     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2013 
2014     auto res = compareBounds(lb, otherLb);
2015     // Identify min.
2016     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2017       minLb = lb;
2018       // Since the divisor is for a floordiv, we need to convert to ceildiv,
2019       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2020       // div * i >= expr - div + 1.
2021       minLb.back() -= lbFloorDivisor - 1;
2022     } else if (res == BoundCmpResult::Greater) {
2023       minLb = otherLb;
2024       minLb.back() -= otherLbFloorDivisor - 1;
2025     } else {
2026       // Uncomparable - check for constant lower/upper bounds.
2027       auto constLb = getConstantBound(BoundType::LB, d);
2028       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
2029       if (!constLb.has_value() || !constOtherLb.has_value())
2030         return failure();
2031       std::fill(minLb.begin(), minLb.end(), 0);
2032       minLb.back() = std::min(constLb.value(), constOtherLb.value());
2033     }
2034 
2035     // Do the same for ub's but max of upper bounds. Identify max.
2036     auto uRes = compareBounds(ub, otherUb);
2037     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2038       maxUb = ub;
2039     } else if (uRes == BoundCmpResult::Less) {
2040       maxUb = otherUb;
2041     } else {
2042       // Uncomparable - check for constant lower/upper bounds.
2043       auto constUb = getConstantBound(BoundType::UB, d);
2044       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
2045       if (!constUb.has_value() || !constOtherUb.has_value())
2046         return failure();
2047       std::fill(maxUb.begin(), maxUb.end(), 0);
2048       maxUb.back() = std::max(constUb.value(), constOtherUb.value());
2049     }
2050 
2051     std::fill(newLb.begin(), newLb.end(), 0);
2052     std::fill(newUb.begin(), newUb.end(), 0);
2053 
2054     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
2055     // and so it's the divisor for newLb and newUb as well.
2056     newLb[d] = lbFloorDivisor;
2057     newUb[d] = -lbFloorDivisor;
2058     // Copy over the symbolic part + constant term.
2059     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars());
2060     std::transform(newLb.begin() + getNumDimVars(), newLb.end(),
2061                    newLb.begin() + getNumDimVars(), std::negate<int64_t>());
2062     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars());
2063 
2064     boundingLbs.push_back(newLb);
2065     boundingUbs.push_back(newUb);
2066   }
2067 
2068   // Clear all constraints and add the lower/upper bounds for the bounding box.
2069   clearConstraints();
2070   for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
2071     addInequality(boundingLbs[d]);
2072     addInequality(boundingUbs[d]);
2073   }
2074 
2075   // Add the constraints that were common to both systems.
2076   append(commonCst);
2077   removeTrivialRedundancy();
2078 
2079   // TODO: copy over pure symbolic constraints from this and 'other' over to the
2080   // union (since the above are just the union along dimensions); we shouldn't
2081   // be discarding any other constraints on the symbols.
2082 
2083   return success();
2084 }
2085 
isColZero(unsigned pos) const2086 bool IntegerRelation::isColZero(unsigned pos) const {
2087   unsigned rowPos;
2088   return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
2089          !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
2090 }
2091 
2092 /// Find positions of inequalities and equalities that do not have a coefficient
2093 /// for [pos, pos + num) variables.
getIndependentConstraints(const IntegerRelation & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)2094 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
2095                                       unsigned num,
2096                                       SmallVectorImpl<unsigned> &nbIneqIndices,
2097                                       SmallVectorImpl<unsigned> &nbEqIndices) {
2098   assert(pos < cst.getNumVars() && "invalid start position");
2099   assert(pos + num <= cst.getNumVars() && "invalid limit");
2100 
2101   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2102     // The bounds are to be independent of [offset, offset + num) columns.
2103     unsigned c;
2104     for (c = pos; c < pos + num; ++c) {
2105       if (cst.atIneq(r, c) != 0)
2106         break;
2107     }
2108     if (c == pos + num)
2109       nbIneqIndices.push_back(r);
2110   }
2111 
2112   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2113     // The bounds are to be independent of [offset, offset + num) columns.
2114     unsigned c;
2115     for (c = pos; c < pos + num; ++c) {
2116       if (cst.atEq(r, c) != 0)
2117         break;
2118     }
2119     if (c == pos + num)
2120       nbEqIndices.push_back(r);
2121   }
2122 }
2123 
removeIndependentConstraints(unsigned pos,unsigned num)2124 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
2125   assert(pos + num <= getNumVars() && "invalid range");
2126 
2127   // Remove constraints that are independent of these variables.
2128   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2129   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2130 
2131   // Iterate in reverse so that indices don't have to be updated.
2132   // TODO: This method can be made more efficient (because removal of each
2133   // inequality leads to much shifting/copying in the underlying buffer).
2134   for (auto nbIndex : llvm::reverse(nbIneqIndices))
2135     removeInequality(nbIndex);
2136   for (auto nbIndex : llvm::reverse(nbEqIndices))
2137     removeEquality(nbIndex);
2138 }
2139 
getDomainSet() const2140 IntegerPolyhedron IntegerRelation::getDomainSet() const {
2141   IntegerRelation copyRel = *this;
2142 
2143   // Convert Range variables to Local variables.
2144   copyRel.convertVarKind(VarKind::Range, 0, getNumVarKind(VarKind::Range),
2145                          VarKind::Local);
2146 
2147   // Convert Domain variables to SetDim(Range) variables.
2148   copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain),
2149                          VarKind::SetDim);
2150 
2151   return IntegerPolyhedron(std::move(copyRel));
2152 }
2153 
getRangeSet() const2154 IntegerPolyhedron IntegerRelation::getRangeSet() const {
2155   IntegerRelation copyRel = *this;
2156 
2157   // Convert Domain variables to Local variables.
2158   copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain),
2159                          VarKind::Local);
2160 
2161   // We do not need to do anything to Range variables since they are already in
2162   // SetDim position.
2163 
2164   return IntegerPolyhedron(std::move(copyRel));
2165 }
2166 
intersectDomain(const IntegerPolyhedron & poly)2167 void IntegerRelation::intersectDomain(const IntegerPolyhedron &poly) {
2168   assert(getDomainSet().getSpace().isCompatible(poly.getSpace()) &&
2169          "Domain set is not compatible with poly");
2170 
2171   // Treating the poly as a relation, convert it from `0 -> R` to `R -> 0`.
2172   IntegerRelation rel = poly;
2173   rel.inverse();
2174 
2175   // Append dummy range variables to make the spaces compatible.
2176   rel.appendVar(VarKind::Range, getNumRangeVars());
2177 
2178   // Intersect in place.
2179   mergeLocalVars(rel);
2180   append(rel);
2181 }
2182 
intersectRange(const IntegerPolyhedron & poly)2183 void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) {
2184   assert(getRangeSet().getSpace().isCompatible(poly.getSpace()) &&
2185          "Range set is not compatible with poly");
2186 
2187   IntegerRelation rel = poly;
2188 
2189   // Append dummy domain variables to make the spaces compatible.
2190   rel.appendVar(VarKind::Domain, getNumDomainVars());
2191 
2192   mergeLocalVars(rel);
2193   append(rel);
2194 }
2195 
inverse()2196 void IntegerRelation::inverse() {
2197   unsigned numRangeVars = getNumVarKind(VarKind::Range);
2198   convertVarKind(VarKind::Domain, 0, getVarKindEnd(VarKind::Domain),
2199                  VarKind::Range);
2200   convertVarKind(VarKind::Range, 0, numRangeVars, VarKind::Domain);
2201 }
2202 
compose(const IntegerRelation & rel)2203 void IntegerRelation::compose(const IntegerRelation &rel) {
2204   assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) &&
2205          "Range of `this` should be compatible with Domain of `rel`");
2206 
2207   IntegerRelation copyRel = rel;
2208 
2209   // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C.
2210   // We convert R1 to A -> (B X C), and R2 to B X C then intersect the range of
2211   // R1 with R2. After this, we get R1: A -> C, by projecting out B.
2212   // TODO: Using nested spaces here would help, since we could directly
2213   // intersect the range with another relation.
2214   unsigned numBVars = getNumRangeVars();
2215 
2216   // Convert R1 from A -> B to A -> (B X C).
2217   appendVar(VarKind::Range, copyRel.getNumRangeVars());
2218 
2219   // Convert R2 to B X C.
2220   copyRel.convertVarKind(VarKind::Domain, 0, numBVars, VarKind::Range, 0);
2221 
2222   // Intersect R2 to range of R1.
2223   intersectRange(IntegerPolyhedron(copyRel));
2224 
2225   // Project out B in R1.
2226   convertVarKind(VarKind::Range, 0, numBVars, VarKind::Local);
2227 }
2228 
applyDomain(const IntegerRelation & rel)2229 void IntegerRelation::applyDomain(const IntegerRelation &rel) {
2230   inverse();
2231   compose(rel);
2232   inverse();
2233 }
2234 
applyRange(const IntegerRelation & rel)2235 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
2236 
printSpace(raw_ostream & os) const2237 void IntegerRelation::printSpace(raw_ostream &os) const {
2238   space.print(os);
2239   os << getNumConstraints() << " constraints\n";
2240 }
2241 
print(raw_ostream & os) const2242 void IntegerRelation::print(raw_ostream &os) const {
2243   assert(hasConsistentState());
2244   printSpace(os);
2245   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2246     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2247       os << atEq(i, j) << " ";
2248     }
2249     os << "= 0\n";
2250   }
2251   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2252     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2253       os << atIneq(i, j) << " ";
2254     }
2255     os << ">= 0\n";
2256   }
2257   os << '\n';
2258 }
2259 
dump() const2260 void IntegerRelation::dump() const { print(llvm::errs()); }
2261 
insertVar(VarKind kind,unsigned pos,unsigned num)2262 unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
2263                                       unsigned num) {
2264   assert((kind != VarKind::Domain || num == 0) &&
2265          "Domain has to be zero in a set");
2266   return IntegerRelation::insertVar(kind, pos, num);
2267 }
2268 IntegerPolyhedron
intersect(const IntegerPolyhedron & other) const2269 IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
2270   return IntegerPolyhedron(IntegerRelation::intersect(other));
2271 }
2272 
/// Returns the result of subtracting `other` from this set, computed by
/// IntegerRelation::subtract and wrapped as a PresburgerSet.
PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
  auto difference = IntegerRelation::subtract(other);
  return PresburgerSet(std::move(difference));
}
2276