1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A class to represent an relation over integer tuples. A relation is 10 // represented as a constraint system over a space of tuples of integer valued 11 // varaiables supporting symbolic identifiers and existential quantification. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/Presburger/IntegerRelation.h" 16 #include "mlir/Analysis/Presburger/LinearTransform.h" 17 #include "mlir/Analysis/Presburger/PWMAFunction.h" 18 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 19 #include "mlir/Analysis/Presburger/Simplex.h" 20 #include "mlir/Analysis/Presburger/Utils.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/DenseSet.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "presburger" 26 27 using namespace mlir; 28 using namespace presburger; 29 30 using llvm::SmallDenseMap; 31 using llvm::SmallDenseSet; 32 33 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const { 34 return std::make_unique<IntegerRelation>(*this); 35 } 36 37 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const { 38 return std::make_unique<IntegerPolyhedron>(*this); 39 } 40 41 void IntegerRelation::append(const IntegerRelation &other) { 42 assert(isSpaceEqual(other) && "Spaces must be equal."); 43 44 inequalities.reserveRows(inequalities.getNumRows() + 45 other.getNumInequalities()); 46 equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities()); 47 48 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { 49 addInequality(other.getInequality(r)); 50 } 51 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { 52 addEquality(other.getEquality(r)); 53 } 54 } 55 56 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const { 57 IntegerRelation result = *this; 58 result.mergeLocalIds(other); 59 result.append(other); 60 return result; 61 } 62 63 bool IntegerRelation::isEqual(const IntegerRelation &other) const { 64 assert(isSpaceEqual(other) && "Spaces must be equal."); 65 return PresburgerRelation(*this).isEqual(PresburgerRelation(other)); 66 } 67 68 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const { 69 assert(isSpaceEqual(other) && "Spaces must be equal."); 70 return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other)); 71 } 72 73 MaybeOptimum<SmallVector<Fraction, 8>> 74 IntegerRelation::findRationalLexMin() const { 75 assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); 76 MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin = 77 LexSimplex(*this).findRationalLexMin(); 78 79 if (!maybeLexMin.isBounded()) 80 return maybeLexMin; 81 82 // The Simplex returns the lexmin over all the variables including locals. But 83 // locals are not actually part of the space and should not be returned in the 84 // result. Since the locals are placed last in the list of identifiers, they 85 // will be minimized last in the lexmin. So simply truncating out the locals 86 // from the end of the answer gives the desired lexmin over the dimensions. 87 assert(maybeLexMin->size() == getNumIds() && 88 "Incorrect number of vars in lexMin!"); 89 maybeLexMin->resize(getNumDimAndSymbolIds()); 90 return maybeLexMin; 91 } 92 93 MaybeOptimum<SmallVector<int64_t, 8>> 94 IntegerRelation::findIntegerLexMin() const { 95 assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); 96 MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin = 97 LexSimplex(*this).findIntegerLexMin(); 98 99 if (!maybeLexMin.isBounded()) 100 return maybeLexMin.getKind(); 101 102 // The Simplex returns the lexmin over all the variables including locals. But 103 // locals are not actually part of the space and should not be returned in the 104 // result. Since the locals are placed last in the list of identifiers, they 105 // will be minimized last in the lexmin. So simply truncating out the locals 106 // from the end of the answer gives the desired lexmin over the dimensions. 107 assert(maybeLexMin->size() == getNumIds() && 108 "Incorrect number of vars in lexMin!"); 109 maybeLexMin->resize(getNumDimAndSymbolIds()); 110 return maybeLexMin; 111 } 112 113 static bool rangeIsZero(ArrayRef<int64_t> range) { 114 return llvm::all_of(range, [](int64_t x) { return x == 0; }); 115 } 116 117 void removeConstraintsInvolvingIdRange(IntegerRelation &poly, unsigned begin, 118 unsigned count) { 119 // We loop until i > 0 and index into i - 1 to avoid sign issues. 120 // 121 // We iterate backwards so that whether we remove constraint i - 1 or not, the 122 // next constraint to be tested is always i - 2. 123 for (unsigned i = poly.getNumEqualities(); i > 0; i--) 124 if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count))) 125 poly.removeEquality(i - 1); 126 for (unsigned i = poly.getNumInequalities(); i > 0; i--) 127 if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count))) 128 poly.removeInequality(i - 1); 129 } 130 131 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const { 132 return {PresburgerSpace(*this), getNumInequalities(), getNumEqualities()}; 133 } 134 135 void IntegerRelation::truncateIdKind(IdKind kind, 136 const CountsSnapshot &counts) { 137 truncateIdKind(kind, counts.getSpace().getNumIdKind(kind)); 138 } 139 140 void IntegerRelation::truncate(const CountsSnapshot &counts) { 141 truncateIdKind(IdKind::Domain, counts); 142 truncateIdKind(IdKind::Range, counts); 143 truncateIdKind(IdKind::Symbol, counts); 144 truncateIdKind(IdKind::Local, counts); 145 removeInequalityRange(counts.getNumIneqs(), getNumInequalities()); 146 removeEqualityRange(counts.getNumEqs(), getNumEqualities()); 147 } 148 149 SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { 150 // Compute the symbolic lexmin of the dims and locals, with the symbols being 151 // the actual symbols of this set. 152 SymbolicLexMin result = 153 SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace( 154 /*numDims=*/getNumSymbolIds()))) 155 .computeSymbolicIntegerLexMin(); 156 157 // We want to return only the lexmin over the dims, so strip the locals from 158 // the computed lexmin. 159 result.lexmin.truncateOutput(result.lexmin.getNumOutputs() - 160 getNumLocalIds()); 161 return result; 162 } 163 164 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { 165 assert(pos <= getNumIdKind(kind)); 166 167 unsigned insertPos = PresburgerSpace::insertId(kind, pos, num); 168 inequalities.insertColumns(insertPos, num); 169 equalities.insertColumns(insertPos, num); 170 return insertPos; 171 } 172 173 unsigned IntegerRelation::appendId(IdKind kind, unsigned num) { 174 unsigned pos = getNumIdKind(kind); 175 return insertId(kind, pos, num); 176 } 177 178 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) { 179 assert(eq.size() == getNumCols()); 180 unsigned row = equalities.appendExtraRow(); 181 for (unsigned i = 0, e = eq.size(); i < e; ++i) 182 equalities(row, i) = eq[i]; 183 } 184 185 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) { 186 assert(inEq.size() == getNumCols()); 187 unsigned row = inequalities.appendExtraRow(); 188 for (unsigned i = 0, e = inEq.size(); i < e; ++i) 189 inequalities(row, i) = inEq[i]; 190 } 191 192 void IntegerRelation::removeId(IdKind kind, unsigned pos) { 193 removeIdRange(kind, pos, pos + 1); 194 } 195 196 void IntegerRelation::removeId(unsigned pos) { removeIdRange(pos, pos + 1); } 197 198 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart, 199 unsigned idLimit) { 200 assert(idLimit <= getNumIdKind(kind)); 201 202 if (idStart >= idLimit) 203 return; 204 205 // Remove eliminated identifiers from the constraints. 206 unsigned offset = getIdKindOffset(kind); 207 equalities.removeColumns(offset + idStart, idLimit - idStart); 208 inequalities.removeColumns(offset + idStart, idLimit - idStart); 209 210 // Remove eliminated identifiers from the space. 211 PresburgerSpace::removeIdRange(kind, idStart, idLimit); 212 } 213 214 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) { 215 assert(idLimit <= getNumIds()); 216 217 if (idStart >= idLimit) 218 return; 219 220 // Helper function to remove ids of the specified kind in the given range 221 // [start, limit), The range is absolute (i.e. it is not relative to the kind 222 // of identifier). Also updates `limit` to reflect the deleted identifiers. 223 auto removeIdKindInRange = [this](IdKind kind, unsigned &start, 224 unsigned &limit) { 225 if (start >= limit) 226 return; 227 228 unsigned offset = getIdKindOffset(kind); 229 unsigned num = getNumIdKind(kind); 230 231 // Get `start`, `limit` relative to the specified kind. 232 unsigned relativeStart = 233 start <= offset ? 0 : std::min(num, start - offset); 234 unsigned relativeLimit = 235 limit <= offset ? 0 : std::min(num, limit - offset); 236 237 // Remove ids of the specified kind in the relative range. 238 removeIdRange(kind, relativeStart, relativeLimit); 239 240 // Update `limit` to reflect deleted identifiers. 241 // `start` does not need to be updated because any identifiers that are 242 // deleted are after position `start`. 243 limit -= relativeLimit - relativeStart; 244 }; 245 246 removeIdKindInRange(IdKind::Domain, idStart, idLimit); 247 removeIdKindInRange(IdKind::Range, idStart, idLimit); 248 removeIdKindInRange(IdKind::Symbol, idStart, idLimit); 249 removeIdKindInRange(IdKind::Local, idStart, idLimit); 250 } 251 252 void IntegerRelation::removeEquality(unsigned pos) { 253 equalities.removeRow(pos); 254 } 255 256 void IntegerRelation::removeInequality(unsigned pos) { 257 inequalities.removeRow(pos); 258 } 259 260 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) { 261 if (start >= end) 262 return; 263 equalities.removeRows(start, end - start); 264 } 265 266 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) { 267 if (start >= end) 268 return; 269 inequalities.removeRows(start, end - start); 270 } 271 272 void IntegerRelation::swapId(unsigned posA, unsigned posB) { 273 assert(posA < getNumIds() && "invalid position A"); 274 assert(posB < getNumIds() && "invalid position B"); 275 276 if (posA == posB) 277 return; 278 279 inequalities.swapColumns(posA, posB); 280 equalities.swapColumns(posA, posB); 281 } 282 283 void IntegerRelation::clearConstraints() { 284 equalities.resizeVertically(0); 285 inequalities.resizeVertically(0); 286 } 287 288 /// Gather all lower and upper bounds of the identifier at `pos`, and 289 /// optionally any equalities on it. In addition, the bounds are to be 290 /// independent of identifiers in position range [`offset`, `offset` + `num`). 291 void IntegerRelation::getLowerAndUpperBoundIndices( 292 unsigned pos, SmallVectorImpl<unsigned> *lbIndices, 293 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices, 294 unsigned offset, unsigned num) const { 295 assert(pos < getNumIds() && "invalid position"); 296 assert(offset + num < getNumCols() && "invalid range"); 297 298 // Checks for a constraint that has a non-zero coeff for the identifiers in 299 // the position range [offset, offset + num) while ignoring `pos`. 300 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) { 301 unsigned c, f; 302 auto cst = isEq ? getEquality(r) : getInequality(r); 303 for (c = offset, f = offset + num; c < f; ++c) { 304 if (c == pos) 305 continue; 306 if (cst[c] != 0) 307 break; 308 } 309 return c < f; 310 }; 311 312 // Gather all lower bounds and upper bounds of the variable. Since the 313 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 314 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 315 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 316 // The bounds are to be independent of [offset, offset + num) columns. 317 if (containsConstraintDependentOnRange(r, /*isEq=*/false)) 318 continue; 319 if (atIneq(r, pos) >= 1) { 320 // Lower bound. 321 lbIndices->push_back(r); 322 } else if (atIneq(r, pos) <= -1) { 323 // Upper bound. 324 ubIndices->push_back(r); 325 } 326 } 327 328 // An equality is both a lower and upper bound. Record any equalities 329 // involving the pos^th identifier. 330 if (!eqIndices) 331 return; 332 333 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 334 if (atEq(r, pos) == 0) 335 continue; 336 if (containsConstraintDependentOnRange(r, /*isEq=*/true)) 337 continue; 338 eqIndices->push_back(r); 339 } 340 } 341 342 bool IntegerRelation::hasConsistentState() const { 343 if (!inequalities.hasConsistentState()) 344 return false; 345 if (!equalities.hasConsistentState()) 346 return false; 347 return true; 348 } 349 350 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) { 351 if (values.empty()) 352 return; 353 assert(pos + values.size() <= getNumIds() && 354 "invalid position or too many values"); 355 // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the 356 // constant term and removing the id x_j. We do this for all the ids 357 // pos, pos + 1, ... pos + values.size() - 1. 358 unsigned constantColPos = getNumCols() - 1; 359 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 360 inequalities.addToColumn(i + pos, constantColPos, values[i]); 361 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 362 equalities.addToColumn(i + pos, constantColPos, values[i]); 363 removeIdRange(pos, pos + values.size()); 364 } 365 366 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) { 367 *this = other; 368 } 369 370 // Searches for a constraint with a non-zero coefficient at `colIdx` in 371 // equality (isEq=true) or inequality (isEq=false) constraints. 372 // Returns true and sets row found in search in `rowIdx`, false otherwise. 373 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, 374 unsigned *rowIdx) const { 375 assert(colIdx < getNumCols() && "position out of bounds"); 376 auto at = [&](unsigned rowIdx) -> int64_t { 377 return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); 378 }; 379 unsigned e = isEq ? getNumEqualities() : getNumInequalities(); 380 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { 381 if (at(*rowIdx) != 0) { 382 return true; 383 } 384 } 385 return false; 386 } 387 388 void IntegerRelation::normalizeConstraintsByGCD() { 389 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 390 equalities.normalizeRow(i); 391 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 392 inequalities.normalizeRow(i); 393 } 394 395 bool IntegerRelation::hasInvalidConstraint() const { 396 assert(hasConsistentState()); 397 auto check = [&](bool isEq) -> bool { 398 unsigned numCols = getNumCols(); 399 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); 400 for (unsigned i = 0, e = numRows; i < e; ++i) { 401 unsigned j; 402 for (j = 0; j < numCols - 1; ++j) { 403 int64_t v = isEq ? atEq(i, j) : atIneq(i, j); 404 // Skip rows with non-zero variable coefficients. 405 if (v != 0) 406 break; 407 } 408 if (j < numCols - 1) { 409 continue; 410 } 411 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. 412 // Example invalid constraints include: '1 == 0' or '-1 >= 0' 413 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); 414 if ((isEq && v != 0) || (!isEq && v < 0)) { 415 return true; 416 } 417 } 418 return false; 419 }; 420 if (check(/*isEq=*/true)) 421 return true; 422 return check(/*isEq=*/false); 423 } 424 425 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at 426 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be 427 /// updated as they have already been eliminated. 428 static void eliminateFromConstraint(IntegerRelation *constraints, 429 unsigned rowIdx, unsigned pivotRow, 430 unsigned pivotCol, unsigned elimColStart, 431 bool isEq) { 432 // Skip if equality 'rowIdx' if same as 'pivotRow'. 433 if (isEq && rowIdx == pivotRow) 434 return; 435 auto at = [&](unsigned i, unsigned j) -> int64_t { 436 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); 437 }; 438 int64_t leadCoeff = at(rowIdx, pivotCol); 439 // Skip if leading coefficient at 'rowIdx' is already zero. 440 if (leadCoeff == 0) 441 return; 442 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); 443 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; 444 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); 445 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); 446 int64_t rowMultiplier = lcm / std::abs(leadCoeff); 447 448 unsigned numCols = constraints->getNumCols(); 449 for (unsigned j = 0; j < numCols; ++j) { 450 // Skip updating column 'j' if it was just eliminated. 451 if (j >= elimColStart && j < pivotCol) 452 continue; 453 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + 454 rowMultiplier * at(rowIdx, j); 455 isEq ? constraints->atEq(rowIdx, j) = v 456 : constraints->atIneq(rowIdx, j) = v; 457 } 458 } 459 460 /// Returns the position of the identifier that has the minimum <number of lower 461 /// bounds> times <number of upper bounds> from the specified range of 462 /// identifiers [start, end). It is often best to eliminate in the increasing 463 /// order of these counts when doing Fourier-Motzkin elimination since FM adds 464 /// that many new constraints. 465 static unsigned getBestIdToEliminate(const IntegerRelation &cst, unsigned start, 466 unsigned end) { 467 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); 468 469 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { 470 unsigned numLb = 0; 471 unsigned numUb = 0; 472 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 473 if (cst.atIneq(r, pos) > 0) { 474 ++numLb; 475 } else if (cst.atIneq(r, pos) < 0) { 476 ++numUb; 477 } 478 } 479 return numLb * numUb; 480 }; 481 482 unsigned minLoc = start; 483 unsigned min = getProductOfNumLowerUpperBounds(start); 484 for (unsigned c = start + 1; c < end; c++) { 485 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); 486 if (numLbUbProduct < min) { 487 min = numLbUbProduct; 488 minLoc = c; 489 } 490 } 491 return minLoc; 492 } 493 494 // Checks for emptiness of the set by eliminating identifiers successively and 495 // using the GCD test (on all equality constraints) and checking for trivially 496 // invalid constraints. Returns 'true' if the constraint system is found to be 497 // empty; false otherwise. 498 bool IntegerRelation::isEmpty() const { 499 if (isEmptyByGCDTest() || hasInvalidConstraint()) 500 return true; 501 502 IntegerRelation tmpCst(*this); 503 504 // First, eliminate as many local variables as possible using equalities. 505 tmpCst.removeRedundantLocalVars(); 506 if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint()) 507 return true; 508 509 // Eliminate as many identifiers as possible using Gaussian elimination. 510 unsigned currentPos = 0; 511 while (currentPos < tmpCst.getNumIds()) { 512 tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); 513 ++currentPos; 514 // We check emptiness through trivial checks after eliminating each ID to 515 // detect emptiness early. Since the checks isEmptyByGCDTest() and 516 // hasInvalidConstraint() are linear time and single sweep on the constraint 517 // buffer, this appears reasonable - but can optimize in the future. 518 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) 519 return true; 520 } 521 522 // Eliminate the remaining using FM. 523 for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { 524 tmpCst.fourierMotzkinEliminate( 525 getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); 526 // Check for a constraint explosion. This rarely happens in practice, but 527 // this check exists as a safeguard against improperly constructed 528 // constraint systems or artificially created arbitrarily complex systems 529 // that aren't the intended use case for IntegerRelation. This is 530 // needed since FM has a worst case exponential complexity in theory. 531 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { 532 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n"); 533 return false; 534 } 535 536 // FM wouldn't have modified the equalities in any way. So no need to again 537 // run GCD test. Check for trivial invalid constraints. 538 if (tmpCst.hasInvalidConstraint()) 539 return true; 540 } 541 return false; 542 } 543 544 // Runs the GCD test on all equality constraints. Returns 'true' if this test 545 // fails on any equality. Returns 'false' otherwise. 546 // This test can be used to disprove the existence of a solution. If it returns 547 // true, no integer solution to the equality constraints can exist. 548 // 549 // GCD test definition: 550 // 551 // The equality constraint: 552 // 553 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 554 // 555 // has an integer solution iff: 556 // 557 // GCD of c_1, c_2, ..., c_n divides c_0. 558 // 559 bool IntegerRelation::isEmptyByGCDTest() const { 560 assert(hasConsistentState()); 561 unsigned numCols = getNumCols(); 562 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 563 uint64_t gcd = std::abs(atEq(i, 0)); 564 for (unsigned j = 1; j < numCols - 1; ++j) { 565 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); 566 } 567 int64_t v = std::abs(atEq(i, numCols - 1)); 568 if (gcd > 0 && (v % gcd != 0)) { 569 return true; 570 } 571 } 572 return false; 573 } 574 575 // Returns a matrix where each row is a vector along which the polytope is 576 // bounded. The span of the returned vectors is guaranteed to contain all 577 // such vectors. The returned vectors are NOT guaranteed to be linearly 578 // independent. This function should not be called on empty sets. 579 // 580 // It is sufficient to check the perpendiculars of the constraints, as the set 581 // of perpendiculars which are bounded must span all bounded directions. 582 Matrix IntegerRelation::getBoundedDirections() const { 583 // Note that it is necessary to add the equalities too (which the constructor 584 // does) even though we don't need to check if they are bounded; whether an 585 // inequality is bounded or not depends on what other constraints, including 586 // equalities, are present. 587 Simplex simplex(*this); 588 589 assert(!simplex.isEmpty() && "It is not meaningful to ask whether a " 590 "direction is bounded in an empty set."); 591 592 SmallVector<unsigned, 8> boundedIneqs; 593 // The constructor adds the inequalities to the simplex first, so this 594 // processes all the inequalities. 595 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 596 if (simplex.isBoundedAlongConstraint(i)) 597 boundedIneqs.push_back(i); 598 } 599 600 // The direction vector is given by the coefficients and does not include the 601 // constant term, so the matrix has one fewer column. 602 unsigned dirsNumCols = getNumCols() - 1; 603 Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols); 604 605 // Copy the bounded inequalities. 606 unsigned row = 0; 607 for (unsigned i : boundedIneqs) { 608 for (unsigned col = 0; col < dirsNumCols; ++col) 609 dirs(row, col) = atIneq(i, col); 610 ++row; 611 } 612 613 // Copy the equalities. All the equalities' perpendiculars are bounded. 614 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 615 for (unsigned col = 0; col < dirsNumCols; ++col) 616 dirs(row, col) = atEq(i, col); 617 ++row; 618 } 619 620 return dirs; 621 } 622 623 bool IntegerRelation::isIntegerEmpty() const { 624 return !findIntegerSample().hasValue(); 625 } 626 627 /// Let this set be S. If S is bounded then we directly call into the GBR 628 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e., 629 /// vectors v such that S extends to infinity along v or -v. In this case we 630 /// use an algorithm described in the integer set library (isl) manual and used 631 /// by the isl_set_sample function in that library. The algorithm is: 632 /// 633 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all 634 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the 635 /// dimensions. 636 /// 637 /// 2) Construct a set B by removing all constraints that involve 638 /// the unbounded dimensions and then deleting the unbounded dimensions. Note 639 /// that B is a Bounded set. 640 /// 641 /// 3) Try to obtain a sample from B using the GBR sampling 642 /// algorithm. If no sample is found, return that S is empty. 643 /// 644 /// 4) Otherwise, substitute the obtained sample into S*T to obtain a set 645 /// C. C is a full-dimensional Cone and always contains a sample. 646 /// 647 /// 5) Obtain an integer sample from C. 648 /// 649 /// 6) Return T*v, where v is the concatenation of the samples from B and C. 650 /// 651 /// The following is a sketch of a proof that 652 /// a) If the algorithm returns empty, then S is empty. 653 /// b) If the algorithm returns a sample, it is a valid sample in S. 654 /// 655 /// The algorithm returns empty only if B is empty, in which case S*T is 656 /// certainly empty since B was obtained by removing constraints and then 657 /// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector 658 /// v is in S*T iff T*v is in S. So in this case, since 659 /// S*T is empty, S is empty too. 660 /// 661 /// Otherwise, the algorithm substitutes the sample from B into S*T. All the 662 /// constraints of S*T that did not involve unbounded dimensions are satisfied 663 /// by this substitution. All dimensions in the linear span of the dimensions 664 /// outside the prefix are unbounded in S*T (step 1). Substituting values for 665 /// the bounded dimensions cannot make these dimensions bounded, and these are 666 /// the only remaining dimensions in C, so C is unbounded along every vector (in 667 /// the positive or negative direction, or both). C is hence a full-dimensional 668 /// cone and therefore always contains an integer point. 669 /// 670 /// Concatenating the samples from B and C gives a sample v in S*T, so the 671 /// returned sample T*v is a sample in S. 672 Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const { 673 // First, try the GCD test heuristic. 674 if (isEmptyByGCDTest()) 675 return {}; 676 677 Simplex simplex(*this); 678 if (simplex.isEmpty()) 679 return {}; 680 681 // For a bounded set, we directly call into the GBR sampling algorithm. 682 if (!simplex.isUnbounded()) 683 return simplex.findIntegerSample(); 684 685 // The set is unbounded. We cannot directly use the GBR algorithm. 686 // 687 // m is a matrix containing, in each row, a vector in which S is 688 // bounded, such that the linear span of all these dimensions contains all 689 // bounded dimensions in S. 690 Matrix m = getBoundedDirections(); 691 // In column echelon form, each row of m occupies only the first rank(m) 692 // columns and has zeros on the other columns. The transform T that brings S 693 // to column echelon form is unimodular as well, so this is a suitable 694 // transform to use in step 1 of the algorithm. 695 std::pair<unsigned, LinearTransform> result = 696 LinearTransform::makeTransformToColumnEchelon(std::move(m)); 697 const LinearTransform &transform = result.second; 698 // 1) Apply T to S to obtain S*T. 699 IntegerRelation transformedSet = transform.applyTo(*this); 700 701 // 2) Remove the unbounded dimensions and constraints involving them to 702 // obtain a bounded set. 703 IntegerRelation boundedSet(transformedSet); 704 unsigned numBoundedDims = result.first; 705 unsigned numUnboundedDims = getNumIds() - numBoundedDims; 706 removeConstraintsInvolvingIdRange(boundedSet, numBoundedDims, 707 numUnboundedDims); 708 boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds()); 709 710 // 3) Try to obtain a sample from the bounded set. 711 Optional<SmallVector<int64_t, 8>> boundedSample = 712 Simplex(boundedSet).findIntegerSample(); 713 if (!boundedSample) 714 return {}; 715 assert(boundedSet.containsPoint(*boundedSample) && 716 "Simplex returned an invalid sample!"); 717 718 // 4) Substitute the values of the bounded dimensions into S*T to obtain a 719 // full-dimensional cone, which necessarily contains an integer sample. 720 transformedSet.setAndEliminate(0, *boundedSample); 721 IntegerRelation &cone = transformedSet; 722 723 // 5) Obtain an integer sample from the cone. 724 // 725 // We shrink the cone such that for any rational point in the shrunken cone, 726 // rounding up each of the point's coordinates produces a point that still 727 // lies in the original cone. 728 // 729 // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i. 730 // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the 731 // shrunken cone will have the inequality tightened by some amount s, such 732 // that if x satisfies the shrunken cone's tightened inequality, then x + e 733 // satisfies the original inequality, i.e., 734 // 735 // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0 736 // 737 // for any e_i values in [0, 1). In fact, we will handle the slightly more 738 // general case where e_i can be in [0, 1]. For example, consider the 739 // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low 740 // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS 741 // is minimized when we add 1 to the x_i with negative coefficient a_i and 742 // keep the other x_i the same. In the example, we would get x = (3, 1, 1), 743 // changing the value of the LHS by -3 + -7 = -10. 744 // 745 // In general, the value of the LHS can change by at most the sum of the 746 // negative a_i, so we accomodate this by shifting the inequality by this 747 // amount for the shrunken cone. 748 for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) { 749 for (unsigned j = 0; j < cone.getNumIds(); ++j) { 750 int64_t coeff = cone.atIneq(i, j); 751 if (coeff < 0) 752 cone.atIneq(i, cone.getNumIds()) += coeff; 753 } 754 } 755 756 // Obtain an integer sample in the cone by rounding up a rational point from 757 // the shrunken cone. Shrinking the cone amounts to shifting its apex 758 // "inwards" without changing its "shape"; the shrunken cone is still a 759 // full-dimensional cone and is hence non-empty. 760 Simplex shrunkenConeSimplex(cone); 761 assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!"); 762 763 // The sample will always exist since the shrunken cone is non-empty. 764 SmallVector<Fraction, 8> shrunkenConeSample = 765 *shrunkenConeSimplex.getRationalSample(); 766 767 SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil)); 768 769 // 6) Return transform * concat(boundedSample, coneSample). 770 SmallVector<int64_t, 8> &sample = boundedSample.getValue(); 771 sample.append(coneSample.begin(), coneSample.end()); 772 return transform.postMultiplyWithColumn(sample); 773 } 774 775 /// Helper to evaluate an affine expression at a point. 776 /// The expression is a list of coefficients for the dimensions followed by the 777 /// constant term. 778 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) { 779 assert(expr.size() == 1 + point.size() && 780 "Dimensionalities of point and expression don't match!"); 781 int64_t value = expr.back(); 782 for (unsigned i = 0; i < point.size(); ++i) 783 value += expr[i] * point[i]; 784 return value; 785 } 786 787 /// A point satisfies an equality iff the value of the equality at the 788 /// expression is zero, and it satisfies an inequality iff the value of the 789 /// inequality at that point is non-negative. 790 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const { 791 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 792 if (valueAt(getEquality(i), point) != 0) 793 return false; 794 } 795 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 796 if (valueAt(getInequality(i), point) < 0) 797 return false; 798 } 799 return true; 800 } 801 802 /// Just substitute the values given and check if an integer sample exists for 803 /// the local ids. 804 /// 805 /// TODO: this could be made more efficient by handling divisions separately. 806 /// Instead of finding an integer sample over all the locals, we can first 807 /// compute the values of the locals that have division representations and 808 /// only use the integer emptiness check for the locals that don't have this. 809 /// Handling this correctly requires ordering the divs, though. 810 Optional<SmallVector<int64_t, 8>> 811 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const { 812 assert(point.size() == getNumIds() - getNumLocalIds() && 813 "Point should contain all ids except locals!"); 814 assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() && 815 "This function depends on locals being stored last!"); 816 IntegerRelation copy = *this; 817 copy.setAndEliminate(0, point); 818 return copy.findIntegerSample(); 819 } 820 821 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const { 822 std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds()); 823 SmallVector<unsigned, 4> denominators(getNumLocalIds()); 824 getLocalReprs(dividends, denominators, repr); 825 } 826 827 void IntegerRelation::getLocalReprs( 828 std::vector<SmallVector<int64_t, 8>> ÷nds, 829 SmallVector<unsigned, 4> &denominators) const { 830 std::vector<MaybeLocalRepr> repr(getNumLocalIds()); 831 getLocalReprs(dividends, denominators, repr); 832 } 833 834 void IntegerRelation::getLocalReprs( 835 std::vector<SmallVector<int64_t, 8>> ÷nds, 836 SmallVector<unsigned, 4> &denominators, 837 std::vector<MaybeLocalRepr> &repr) const { 838 839 repr.resize(getNumLocalIds()); 840 dividends.resize(getNumLocalIds()); 841 denominators.resize(getNumLocalIds()); 842 843 SmallVector<bool, 8> foundRepr(getNumIds(), false); 844 for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) 845 foundRepr[i] = true; 846 847 unsigned divOffset = getNumDimAndSymbolIds(); 848 bool changed; 849 do { 850 // Each time changed is true, at end of this iteration, one or more local 851 // vars have been detected as floor divs. 852 changed = false; 853 for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { 854 if (!foundRepr[i + divOffset]) { 855 MaybeLocalRepr res = computeSingleVarRepr( 856 *this, foundRepr, divOffset + i, dividends[i], denominators[i]); 857 if (!res) 858 continue; 859 foundRepr[i + divOffset] = true; 860 repr[i] = res; 861 changed = true; 862 } 863 } 864 } while (changed); 865 866 // Set 0 denominator for identifiers for which no division representation 867 // could be found. 868 for (unsigned i = 0, e = repr.size(); i < e; ++i) 869 if (!repr[i]) 870 denominators[i] = 0; 871 } 872 873 /// Tightens inequalities given that we are dealing with integer spaces. This is 874 /// analogous to the GCD test but applied to inequalities. The constant term can 875 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., 876 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a 877 /// fast method - linear in the number of coefficients. 878 // Example on how this affects practical cases: consider the scenario: 879 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield 880 // j >= 100 instead of the tighter (exact) j >= 128. 881 void IntegerRelation::gcdTightenInequalities() { 882 unsigned numCols = getNumCols(); 883 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 884 // Normalize the constraint and tighten the constant term by the GCD. 885 uint64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1); 886 if (gcd > 1) 887 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd); 888 } 889 } 890 891 // Eliminates all identifier variables in column range [posStart, posLimit). 892 // Returns the number of variables eliminated. 893 unsigned IntegerRelation::gaussianEliminateIds(unsigned posStart, 894 unsigned posLimit) { 895 // Return if identifier positions to eliminate are out of range. 896 assert(posLimit <= getNumIds()); 897 assert(hasConsistentState()); 898 899 if (posStart >= posLimit) 900 return 0; 901 902 gcdTightenInequalities(); 903 904 unsigned pivotCol = 0; 905 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { 906 // Find a row which has a non-zero coefficient in column 'j'. 907 unsigned pivotRow; 908 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) { 909 // No pivot row in equalities with non-zero at 'pivotCol'. 910 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) { 911 // If inequalities are also non-zero in 'pivotCol', it can be 912 // eliminated. 913 continue; 914 } 915 break; 916 } 917 918 // Eliminate identifier at 'pivotCol' from each equality row. 919 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 920 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 921 /*isEq=*/true); 922 equalities.normalizeRow(i); 923 } 924 925 // Eliminate identifier at 'pivotCol' from each inequality row. 926 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 927 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 928 /*isEq=*/false); 929 inequalities.normalizeRow(i); 930 } 931 removeEquality(pivotRow); 932 gcdTightenInequalities(); 933 } 934 // Update position limit based on number eliminated. 935 posLimit = pivotCol; 936 // Remove eliminated columns from all constraints. 937 removeIdRange(posStart, posLimit); 938 return posLimit - posStart; 939 } 940 941 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin 942 // to check if a constraint is redundant. 943 void IntegerRelation::removeRedundantInequalities() { 944 SmallVector<bool, 32> redun(getNumInequalities(), false); 945 // To check if an inequality is redundant, we replace the inequality by its 946 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting 947 // system is empty. If it is, the inequality is redundant. 948 IntegerRelation tmpCst(*this); 949 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 950 // Change the inequality to its complement. 951 tmpCst.inequalities.negateRow(r); 952 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; 953 if (tmpCst.isEmpty()) { 954 redun[r] = true; 955 // Zero fill the redundant inequality. 956 inequalities.fillRow(r, /*value=*/0); 957 tmpCst.inequalities.fillRow(r, /*value=*/0); 958 } else { 959 // Reverse the change (to avoid recreating tmpCst each time). 960 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; 961 tmpCst.inequalities.negateRow(r); 962 } 963 } 964 965 unsigned pos = 0; 966 for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { 967 if (!redun[r]) 968 inequalities.copyRow(r, pos++); 969 } 970 inequalities.resizeVertically(pos); 971 } 972 973 // A more complex check to eliminate redundant inequalities and equalities. Uses 974 // Simplex to check if a constraint is redundant. 975 void IntegerRelation::removeRedundantConstraints() { 976 // First, we run gcdTightenInequalities. This allows us to catch some 977 // constraints which are not redundant when considering rational solutions 978 // but are redundant in terms of integer solutions. 979 gcdTightenInequalities(); 980 Simplex simplex(*this); 981 simplex.detectRedundant(); 982 983 unsigned pos = 0; 984 unsigned numIneqs = getNumInequalities(); 985 // Scan to get rid of all inequalities marked redundant, in-place. In Simplex, 986 // the first constraints added are the inequalities. 987 for (unsigned r = 0; r < numIneqs; r++) { 988 if (!simplex.isMarkedRedundant(r)) 989 inequalities.copyRow(r, pos++); 990 } 991 inequalities.resizeVertically(pos); 992 993 // Scan to get rid of all equalities marked redundant, in-place. In Simplex, 994 // after the inequalities, a pair of constraints for each equality is added. 995 // An equality is redundant if both the inequalities in its pair are 996 // redundant. 997 pos = 0; 998 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 999 if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) && 1000 simplex.isMarkedRedundant(numIneqs + 2 * r + 1))) 1001 equalities.copyRow(r, pos++); 1002 } 1003 equalities.resizeVertically(pos); 1004 } 1005 1006 Optional<uint64_t> IntegerRelation::computeVolume() const { 1007 assert(getNumSymbolIds() == 0 && "Symbols are not yet supported!"); 1008 1009 Simplex simplex(*this); 1010 // If the polytope is rationally empty, there are certainly no integer 1011 // points. 1012 if (simplex.isEmpty()) 1013 return 0; 1014 1015 // Just find the maximum and minimum integer value of each non-local id 1016 // separately, thus finding the number of integer values each such id can 1017 // take. Multiplying these together gives a valid overapproximation of the 1018 // number of integer points in the relation. The result this gives is 1019 // equivalent to projecting (rationally) the relation onto its non-local ids 1020 // and returning the number of integer points in a minimal axis-parallel 1021 // hyperrectangular overapproximation of that. 1022 // 1023 // We also handle the special case where one dimension is unbounded and 1024 // another dimension can take no integer values. In this case, the volume is 1025 // zero. 1026 // 1027 // If there is no such empty dimension, if any dimension is unbounded we 1028 // just return the result as unbounded. 1029 uint64_t count = 1; 1030 SmallVector<int64_t, 8> dim(getNumIds() + 1); 1031 bool hasUnboundedId = false; 1032 for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) { 1033 dim[i] = 1; 1034 MaybeOptimum<int64_t> min, max; 1035 std::tie(min, max) = simplex.computeIntegerBounds(dim); 1036 dim[i] = 0; 1037 1038 assert((!min.isEmpty() && !max.isEmpty()) && 1039 "Polytope should be rationally non-empty!"); 1040 1041 // One of the dimensions is unbounded. Note this fact. We will return 1042 // unbounded if none of the other dimensions makes the volume zero. 1043 if (min.isUnbounded() || max.isUnbounded()) { 1044 hasUnboundedId = true; 1045 continue; 1046 } 1047 1048 // In this case there are no valid integer points and the volume is 1049 // definitely zero. 1050 if (min.getBoundedOptimum() > max.getBoundedOptimum()) 1051 return 0; 1052 1053 count *= (*max - *min + 1); 1054 } 1055 1056 if (count == 0) 1057 return 0; 1058 if (hasUnboundedId) 1059 return {}; 1060 return count; 1061 } 1062 1063 void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) { 1064 assert(posA < getNumLocalIds() && "Invalid local id position"); 1065 assert(posB < getNumLocalIds() && "Invalid local id position"); 1066 1067 unsigned localOffset = getIdKindOffset(IdKind::Local); 1068 posA += localOffset; 1069 posB += localOffset; 1070 inequalities.addToColumn(posB, posA, 1); 1071 equalities.addToColumn(posB, posA, 1); 1072 removeId(posB); 1073 } 1074 1075 /// Adds additional local ids to the sets such that they both have the union 1076 /// of the local ids in each set, without changing the set of points that 1077 /// lie in `this` and `other`. 1078 /// 1079 /// To detect local ids that always take the same in both sets, each local id is 1080 /// represented as a floordiv with constant denominator in terms of other ids. 1081 /// After extracting these divisions, local ids with the same division 1082 /// representation are considered duplicate and are merged. It is possible that 1083 /// division representation for some local id cannot be obtained, and thus these 1084 /// local ids are not considered for detecting duplicates. 1085 void IntegerRelation::mergeLocalIds(IntegerRelation &other) { 1086 assert(isSpaceCompatible(other) && "Spaces should be compatible."); 1087 1088 IntegerRelation &relA = *this; 1089 IntegerRelation &relB = other; 1090 1091 // Merge local ids of relA and relB without using division information, 1092 // i.e. append local ids of `relB` to `relA` and insert local ids of `relA` 1093 // to `relB` at start of its local ids. 1094 unsigned initLocals = relA.getNumLocalIds(); 1095 insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds()); 1096 relB.insertId(IdKind::Local, 0, initLocals); 1097 1098 // Get division representations from each rel. 1099 std::vector<SmallVector<int64_t, 8>> divsA, divsB; 1100 SmallVector<unsigned, 4> denomsA, denomsB; 1101 relA.getLocalReprs(divsA, denomsA); 1102 relB.getLocalReprs(divsB, denomsB); 1103 1104 // Copy division information for relB into `divsA` and `denomsA`, so that 1105 // these have the combined division information of both rels. Since newly 1106 // added local variables in relA and relB have no constraints, they will not 1107 // have any division representation. 1108 std::copy(divsB.begin() + initLocals, divsB.end(), 1109 divsA.begin() + initLocals); 1110 std::copy(denomsB.begin() + initLocals, denomsB.end(), 1111 denomsA.begin() + initLocals); 1112 1113 // Merge function that merges the local variables in both sets by treating 1114 // them as the same identifier. 1115 auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool { 1116 relA.eliminateRedundantLocalId(i, j); 1117 relB.eliminateRedundantLocalId(i, j); 1118 return true; 1119 }; 1120 1121 // Merge all divisions by removing duplicate divisions. 1122 unsigned localOffset = getIdKindOffset(IdKind::Local); 1123 presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge); 1124 } 1125 1126 void IntegerRelation::removeDuplicateDivs() { 1127 std::vector<SmallVector<int64_t, 8>> divs; 1128 SmallVector<unsigned, 4> denoms; 1129 1130 getLocalReprs(divs, denoms); 1131 auto merge = [this](unsigned i, unsigned j) -> bool { 1132 eliminateRedundantLocalId(i, j); 1133 return true; 1134 }; 1135 presburger::removeDuplicateDivs(divs, denoms, getIdKindOffset(IdKind::Local), 1136 merge); 1137 } 1138 1139 /// Removes local variables using equalities. Each equality is checked if it 1140 /// can be reduced to the form: `e = affine-expr`, where `e` is a local 1141 /// variable and `affine-expr` is an affine expression not containing `e`. 1142 /// If an equality satisfies this form, the local variable is replaced in 1143 /// each constraint and then removed. The equality used to replace this local 1144 /// variable is also removed. 1145 void IntegerRelation::removeRedundantLocalVars() { 1146 // Normalize the equality constraints to reduce coefficients of local 1147 // variables to 1 wherever possible. 1148 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 1149 equalities.normalizeRow(i); 1150 1151 while (true) { 1152 unsigned i, e, j, f; 1153 for (i = 0, e = getNumEqualities(); i < e; ++i) { 1154 // Find a local variable to eliminate using ith equality. 1155 for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j) 1156 if (std::abs(atEq(i, j)) == 1) 1157 break; 1158 1159 // Local variable can be eliminated using ith equality. 1160 if (j < f) 1161 break; 1162 } 1163 1164 // No equality can be used to eliminate a local variable. 1165 if (i == e) 1166 break; 1167 1168 // Use the ith equality to simplify other equalities. If any changes 1169 // are made to an equality constraint, it is normalized by GCD. 1170 for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) { 1171 if (atEq(k, j) != 0) { 1172 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true); 1173 equalities.normalizeRow(k); 1174 } 1175 } 1176 1177 // Use the ith equality to simplify inequalities. 1178 for (unsigned k = 0, t = getNumInequalities(); k < t; ++k) 1179 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false); 1180 1181 // Remove the ith equality and the found local variable. 1182 removeId(j); 1183 removeEquality(i); 1184 } 1185 } 1186 1187 void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart, 1188 unsigned idLimit, IdKind dstKind) { 1189 assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range"); 1190 1191 if (idStart >= idLimit) 1192 return; 1193 1194 // Append new local variables corresponding to the dimensions to be converted. 1195 unsigned newIdsBegin = getIdKindEnd(dstKind); 1196 unsigned convertCount = idLimit - idStart; 1197 appendId(dstKind, convertCount); 1198 1199 // Swap the new local variables with dimensions. 1200 // 1201 // Essentially, this moves the information corresponding to the specified ids 1202 // of kind `srcKind` to the `convertCount` newly created ids of kind 1203 // `dstKind`. In particular, this moves the columns in the constraint 1204 // matrices, and zeros out the initially occupied columns (because the newly 1205 // created ids we're swapping with were zero-initialized). 1206 unsigned offset = getIdKindOffset(srcKind); 1207 for (unsigned i = 0; i < convertCount; ++i) 1208 swapId(offset + idStart + i, newIdsBegin + i); 1209 1210 // Complete the move by deleting the initially occupied columns. 1211 removeIdRange(srcKind, idStart, idLimit); 1212 } 1213 1214 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) { 1215 assert(pos < getNumCols()); 1216 if (type == BoundType::EQ) { 1217 unsigned row = equalities.appendExtraRow(); 1218 equalities(row, pos) = 1; 1219 equalities(row, getNumCols() - 1) = -value; 1220 } else { 1221 unsigned row = inequalities.appendExtraRow(); 1222 inequalities(row, pos) = type == BoundType::LB ? 1 : -1; 1223 inequalities(row, getNumCols() - 1) = 1224 type == BoundType::LB ? -value : value; 1225 } 1226 } 1227 1228 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr, 1229 int64_t value) { 1230 assert(type != BoundType::EQ && "EQ not implemented"); 1231 assert(expr.size() == getNumCols()); 1232 unsigned row = inequalities.appendExtraRow(); 1233 for (unsigned i = 0, e = expr.size(); i < e; ++i) 1234 inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i]; 1235 inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += 1236 type == BoundType::LB ? -value : value; 1237 } 1238 1239 /// Adds a new local identifier as the floordiv of an affine function of other 1240 /// identifiers, the coefficients of which are provided in 'dividend' and with 1241 /// respect to a positive constant 'divisor'. Two constraints are added to the 1242 /// system to capture equivalence with the floordiv. 1243 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. 1244 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend, 1245 int64_t divisor) { 1246 assert(dividend.size() == getNumCols() && "incorrect dividend size"); 1247 assert(divisor > 0 && "positive divisor expected"); 1248 1249 appendId(IdKind::Local); 1250 1251 // Add two constraints for this new identifier 'q'. 1252 SmallVector<int64_t, 8> bound(dividend.size() + 1); 1253 1254 // dividend - q * divisor >= 0 1255 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, 1256 bound.begin()); 1257 bound.back() = dividend.back(); 1258 bound[getNumIds() - 1] = -divisor; 1259 addInequality(bound); 1260 1261 // -dividend +qdivisor * q + divisor - 1 >= 0 1262 std::transform(bound.begin(), bound.end(), bound.begin(), 1263 std::negate<int64_t>()); 1264 bound[bound.size() - 1] += divisor - 1; 1265 addInequality(bound); 1266 } 1267 1268 /// Finds an equality that equates the specified identifier to a constant. 1269 /// Returns the position of the equality row. If 'symbolic' is set to true, 1270 /// symbols are also treated like a constant, i.e., an affine function of the 1271 /// symbols is also treated like a constant. Returns -1 if such an equality 1272 /// could not be found. 1273 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, 1274 bool symbolic = false) { 1275 assert(pos < cst.getNumIds() && "invalid position"); 1276 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1277 int64_t v = cst.atEq(r, pos); 1278 if (v * v != 1) 1279 continue; 1280 unsigned c; 1281 unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds(); 1282 // This checks for zeros in all positions other than 'pos' in [0, f) 1283 for (c = 0; c < f; c++) { 1284 if (c == pos) 1285 continue; 1286 if (cst.atEq(r, c) != 0) { 1287 // Dependent on another identifier. 1288 break; 1289 } 1290 } 1291 if (c == f) 1292 // Equality is free of other identifiers. 1293 return r; 1294 } 1295 return -1; 1296 } 1297 1298 LogicalResult IntegerRelation::constantFoldId(unsigned pos) { 1299 assert(pos < getNumIds() && "invalid position"); 1300 int rowIdx; 1301 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) 1302 return failure(); 1303 1304 // atEq(rowIdx, pos) is either -1 or 1. 1305 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); 1306 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); 1307 setAndEliminate(pos, constVal); 1308 return success(); 1309 } 1310 1311 void IntegerRelation::constantFoldIdRange(unsigned pos, unsigned num) { 1312 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { 1313 if (failed(constantFoldId(t))) 1314 t++; 1315 } 1316 } 1317 1318 /// Returns a non-negative constant bound on the extent (upper bound - lower 1319 /// bound) of the specified identifier if it is found to be a constant; returns 1320 /// None if it's not a constant. This methods treats symbolic identifiers 1321 /// specially, i.e., it looks for constant differences between affine 1322 /// expressions involving only the symbolic identifiers. See comments at 1323 /// function definition for example. 'lb', if provided, is set to the lower 1324 /// bound associated with the constant difference. Note that 'lb' is purely 1325 /// symbolic and thus will contain the coefficients of the symbolic identifiers 1326 /// and the constant coefficient. 1327 // Egs: 0 <= i <= 15, return 16. 1328 // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) 1329 // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. 1330 // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = 1331 // ceil(s0 - 7 / 8) = floor(s0 / 8)). 1332 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize( 1333 unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor, 1334 SmallVectorImpl<int64_t> *ub, unsigned *minLbPos, 1335 unsigned *minUbPos) const { 1336 assert(pos < getNumDimIds() && "Invalid identifier position"); 1337 1338 // Find an equality for 'pos'^th identifier that equates it to some function 1339 // of the symbolic identifiers (+ constant). 1340 int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); 1341 if (eqPos != -1) { 1342 auto eq = getEquality(eqPos); 1343 // If the equality involves a local var, punt for now. 1344 // TODO: this can be handled in the future by using the explicit 1345 // representation of the local vars. 1346 if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1, 1347 [](int64_t coeff) { return coeff == 0; })) 1348 return None; 1349 1350 // This identifier can only take a single value. 1351 if (lb) { 1352 // Set lb to that symbolic value. 1353 lb->resize(getNumSymbolIds() + 1); 1354 if (ub) 1355 ub->resize(getNumSymbolIds() + 1); 1356 for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { 1357 int64_t v = atEq(eqPos, pos); 1358 // atEq(eqRow, pos) is either -1 or 1. 1359 assert(v * v == 1); 1360 (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v 1361 : -atEq(eqPos, getNumDimIds() + c) / v; 1362 // Since this is an equality, ub = lb. 1363 if (ub) 1364 (*ub)[c] = (*lb)[c]; 1365 } 1366 assert(boundFloorDivisor && 1367 "both lb and divisor or none should be provided"); 1368 *boundFloorDivisor = 1; 1369 } 1370 if (minLbPos) 1371 *minLbPos = eqPos; 1372 if (minUbPos) 1373 *minUbPos = eqPos; 1374 return 1; 1375 } 1376 1377 // Check if the identifier appears at all in any of the inequalities. 1378 unsigned r, e; 1379 for (r = 0, e = getNumInequalities(); r < e; r++) { 1380 if (atIneq(r, pos) != 0) 1381 break; 1382 } 1383 if (r == e) 1384 // If it doesn't, there isn't a bound on it. 1385 return None; 1386 1387 // Positions of constraints that are lower/upper bounds on the variable. 1388 SmallVector<unsigned, 4> lbIndices, ubIndices; 1389 1390 // Gather all symbolic lower bounds and upper bounds of the variable, i.e., 1391 // the bounds can only involve symbolic (and local) identifiers. Since the 1392 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1393 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1394 getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, 1395 /*eqIndices=*/nullptr, /*offset=*/0, 1396 /*num=*/getNumDimIds()); 1397 1398 Optional<int64_t> minDiff = None; 1399 unsigned minLbPosition = 0, minUbPosition = 0; 1400 for (auto ubPos : ubIndices) { 1401 for (auto lbPos : lbIndices) { 1402 // Look for a lower bound and an upper bound that only differ by a 1403 // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. 1404 // For example, if ii is the pos^th variable, we are looking for 1405 // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The 1406 // minimum among all such constant differences is kept since that's the 1407 // constant bounding the extent of the pos^th variable. 1408 unsigned j, e; 1409 for (j = 0, e = getNumCols() - 1; j < e; j++) 1410 if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { 1411 break; 1412 } 1413 if (j < getNumCols() - 1) 1414 continue; 1415 int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + 1416 atIneq(lbPos, getNumCols() - 1) + 1, 1417 atIneq(lbPos, pos)); 1418 // This bound is non-negative by definition. 1419 diff = std::max<int64_t>(diff, 0); 1420 if (minDiff == None || diff < minDiff) { 1421 minDiff = diff; 1422 minLbPosition = lbPos; 1423 minUbPosition = ubPos; 1424 } 1425 } 1426 } 1427 if (lb && minDiff.hasValue()) { 1428 // Set lb to the symbolic lower bound. 1429 lb->resize(getNumSymbolIds() + 1); 1430 if (ub) 1431 ub->resize(getNumSymbolIds() + 1); 1432 // The lower bound is the ceildiv of the lb constraint over the coefficient 1433 // of the variable at 'pos'. We express the ceildiv equivalently as a floor 1434 // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + 1435 // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). 1436 *boundFloorDivisor = atIneq(minLbPosition, pos); 1437 assert(*boundFloorDivisor == -atIneq(minUbPosition, pos)); 1438 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { 1439 (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c); 1440 } 1441 if (ub) { 1442 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) 1443 (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c); 1444 } 1445 // The lower bound leads to a ceildiv while the upper bound is a floordiv 1446 // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val + 1447 // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to 1448 // the constant term for the lower bound. 1449 (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1; 1450 } 1451 if (minLbPos) 1452 *minLbPos = minLbPosition; 1453 if (minUbPos) 1454 *minUbPos = minUbPosition; 1455 return minDiff; 1456 } 1457 1458 template <bool isLower> 1459 Optional<int64_t> 1460 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { 1461 assert(pos < getNumIds() && "invalid position"); 1462 // Project to 'pos'. 1463 projectOut(0, pos); 1464 projectOut(1, getNumIds() - 1); 1465 // Check if there's an equality equating the '0'^th identifier to a constant. 1466 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); 1467 if (eqRowIdx != -1) 1468 // atEq(rowIdx, 0) is either -1 or 1. 1469 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); 1470 1471 // Check if the identifier appears at all in any of the inequalities. 1472 unsigned r, e; 1473 for (r = 0, e = getNumInequalities(); r < e; r++) { 1474 if (atIneq(r, 0) != 0) 1475 break; 1476 } 1477 if (r == e) 1478 // If it doesn't, there isn't a bound on it. 1479 return None; 1480 1481 Optional<int64_t> minOrMaxConst = None; 1482 1483 // Take the max across all const lower bounds (or min across all constant 1484 // upper bounds). 1485 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1486 if (isLower) { 1487 if (atIneq(r, 0) <= 0) 1488 // Not a lower bound. 1489 continue; 1490 } else if (atIneq(r, 0) >= 0) { 1491 // Not an upper bound. 1492 continue; 1493 } 1494 unsigned c, f; 1495 for (c = 0, f = getNumCols() - 1; c < f; c++) 1496 if (c != 0 && atIneq(r, c) != 0) 1497 break; 1498 if (c < getNumCols() - 1) 1499 // Not a constant bound. 1500 continue; 1501 1502 int64_t boundConst = 1503 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) 1504 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); 1505 if (isLower) { 1506 if (minOrMaxConst == None || boundConst > minOrMaxConst) 1507 minOrMaxConst = boundConst; 1508 } else { 1509 if (minOrMaxConst == None || boundConst < minOrMaxConst) 1510 minOrMaxConst = boundConst; 1511 } 1512 } 1513 return minOrMaxConst; 1514 } 1515 1516 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type, 1517 unsigned pos) const { 1518 if (type == BoundType::LB) 1519 return IntegerRelation(*this) 1520 .computeConstantLowerOrUpperBound</*isLower=*/true>(pos); 1521 if (type == BoundType::UB) 1522 return IntegerRelation(*this) 1523 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1524 1525 assert(type == BoundType::EQ && "expected EQ"); 1526 Optional<int64_t> lb = 1527 IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>( 1528 pos); 1529 Optional<int64_t> ub = 1530 IntegerRelation(*this) 1531 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1532 return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None; 1533 } 1534 1535 // A simple (naive and conservative) check for hyper-rectangularity. 1536 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const { 1537 assert(pos < getNumCols() - 1); 1538 // Check for two non-zero coefficients in the range [pos, pos + sum). 1539 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1540 unsigned sum = 0; 1541 for (unsigned c = pos; c < pos + num; c++) { 1542 if (atIneq(r, c) != 0) 1543 sum++; 1544 } 1545 if (sum > 1) 1546 return false; 1547 } 1548 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1549 unsigned sum = 0; 1550 for (unsigned c = pos; c < pos + num; c++) { 1551 if (atEq(r, c) != 0) 1552 sum++; 1553 } 1554 if (sum > 1) 1555 return false; 1556 } 1557 return true; 1558 } 1559 1560 /// Removes duplicate constraints, trivially true constraints, and constraints 1561 /// that can be detected as redundant as a result of differing only in their 1562 /// constant term part. A constraint of the form <non-negative constant> >= 0 is 1563 /// considered trivially true. 1564 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to 1565 // remove duplicates in place. 1566 void IntegerRelation::removeTrivialRedundancy() { 1567 gcdTightenInequalities(); 1568 normalizeConstraintsByGCD(); 1569 1570 // A map used to detect redundancy stemming from constraints that only differ 1571 // in their constant term. The value stored is <row position, const term> 1572 // for a given row. 1573 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>> 1574 rowsWithoutConstTerm; 1575 // To unique rows. 1576 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet; 1577 1578 // Check if constraint is of the form <non-negative-constant> >= 0. 1579 auto isTriviallyValid = [&](unsigned r) -> bool { 1580 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { 1581 if (atIneq(r, c) != 0) 1582 return false; 1583 } 1584 return atIneq(r, getNumCols() - 1) >= 0; 1585 }; 1586 1587 // Detect and mark redundant constraints. 1588 SmallVector<bool, 256> redunIneq(getNumInequalities(), false); 1589 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1590 int64_t *rowStart = &inequalities(r, 0); 1591 auto row = ArrayRef<int64_t>(rowStart, getNumCols()); 1592 if (isTriviallyValid(r) || !rowSet.insert(row).second) { 1593 redunIneq[r] = true; 1594 continue; 1595 } 1596 1597 // Among constraints that only differ in the constant term part, mark 1598 // everything other than the one with the smallest constant term redundant. 1599 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the 1600 // former two are redundant). 1601 int64_t constTerm = atIneq(r, getNumCols() - 1); 1602 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1); 1603 const auto &ret = 1604 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); 1605 if (!ret.second) { 1606 // Check if the other constraint has a higher constant term. 1607 auto &val = ret.first->second; 1608 if (val.second > constTerm) { 1609 // The stored row is redundant. Mark it so, and update with this one. 1610 redunIneq[val.first] = true; 1611 val = {r, constTerm}; 1612 } else { 1613 // The one stored makes this one redundant. 1614 redunIneq[r] = true; 1615 } 1616 } 1617 } 1618 1619 // Scan to get rid of all rows marked redundant, in-place. 1620 unsigned pos = 0; 1621 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) 1622 if (!redunIneq[r]) 1623 inequalities.copyRow(r, pos++); 1624 1625 inequalities.resizeVertically(pos); 1626 1627 // TODO: consider doing this for equalities as well, but probably not worth 1628 // the savings. 1629 } 1630 1631 #undef DEBUG_TYPE 1632 #define DEBUG_TYPE "fm" 1633 1634 /// Eliminates identifier at the specified position using Fourier-Motzkin 1635 /// variable elimination. This technique is exact for rational spaces but 1636 /// conservative (in "rare" cases) for integer spaces. The operation corresponds 1637 /// to a projection operation yielding the (convex) set of integer points 1638 /// contained in the rational shadow of the set. An emptiness test that relies 1639 /// on this method will guarantee emptiness, i.e., it disproves the existence of 1640 /// a solution if it says it's empty. 1641 /// If a non-null isResultIntegerExact is passed, it is set to true if the 1642 /// result is also integer exact. If it's set to false, the obtained solution 1643 /// *may* not be exact, i.e., it may contain integer points that do not have an 1644 /// integer pre-image in the original set. 1645 /// 1646 /// Eg: 1647 /// j >= 0, j <= i + 1 1648 /// i >= 0, i <= N + 1 1649 /// Eliminating i yields, 1650 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1 1651 /// 1652 /// If darkShadow = true, this method computes the dark shadow on elimination; 1653 /// the dark shadow is a convex integer subset of the exact integer shadow. A 1654 /// non-empty dark shadow proves the existence of an integer solution. The 1655 /// elimination in such a case could however be an under-approximation, and thus 1656 /// should not be used for scanning sets or used by itself for dependence 1657 /// checking. 1658 /// 1659 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. 1660 /// ^ 1661 /// | 1662 /// | * * * * o o 1663 /// i | * * o o o o 1664 /// | o * * * * * 1665 /// ---------------> 1666 /// j -> 1667 /// 1668 /// Eliminating i from this system (projecting on the j dimension): 1669 /// rational shadow / integer light shadow: 1 <= j <= 6 1670 /// dark shadow: 3 <= j <= 6 1671 /// exact integer shadow: j = 1 \union 3 <= j <= 6 1672 /// holes/splinters: j = 2 1673 /// 1674 /// darkShadow = false, isResultIntegerExact = nullptr are default values. 1675 // TODO: a slight modification to yield dark shadow version of FM (tightened), 1676 // which can prove the existence of a solution if there is one. 1677 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, 1678 bool *isResultIntegerExact) { 1679 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); 1680 LLVM_DEBUG(dump()); 1681 assert(pos < getNumIds() && "invalid position"); 1682 assert(hasConsistentState()); 1683 1684 // Check if this identifier can be eliminated through a substitution. 1685 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1686 if (atEq(r, pos) != 0) { 1687 // Use Gaussian elimination here (since we have an equality). 1688 LogicalResult ret = gaussianEliminateId(pos); 1689 (void)ret; 1690 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); 1691 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); 1692 LLVM_DEBUG(dump()); 1693 return; 1694 } 1695 } 1696 1697 // A fast linear time tightening. 1698 gcdTightenInequalities(); 1699 1700 // Check if the identifier appears at all in any of the inequalities. 1701 if (isColZero(pos)) { 1702 // If it doesn't appear, just remove the column and return. 1703 // TODO: refactor removeColumns to use it from here. 1704 removeId(pos); 1705 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1706 LLVM_DEBUG(dump()); 1707 return; 1708 } 1709 1710 // Positions of constraints that are lower bounds on the variable. 1711 SmallVector<unsigned, 4> lbIndices; 1712 // Positions of constraints that are lower bounds on the variable. 1713 SmallVector<unsigned, 4> ubIndices; 1714 // Positions of constraints that do not involve the variable. 1715 std::vector<unsigned> nbIndices; 1716 nbIndices.reserve(getNumInequalities()); 1717 1718 // Gather all lower bounds and upper bounds of the variable. Since the 1719 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1720 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1721 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1722 if (atIneq(r, pos) == 0) { 1723 // Id does not appear in bound. 1724 nbIndices.push_back(r); 1725 } else if (atIneq(r, pos) >= 1) { 1726 // Lower bound. 1727 lbIndices.push_back(r); 1728 } else { 1729 // Upper bound. 1730 ubIndices.push_back(r); 1731 } 1732 } 1733 1734 PresburgerSpace newSpace = getSpace(); 1735 IdKind idKindRemove = newSpace.getIdKindAt(pos); 1736 unsigned relativePos = pos - newSpace.getIdKindOffset(idKindRemove); 1737 newSpace.removeIdRange(idKindRemove, relativePos, relativePos + 1); 1738 1739 /// Create the new system which has one identifier less. 1740 IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(), 1741 getNumEqualities(), getNumCols() - 1, newSpace); 1742 1743 // This will be used to check if the elimination was integer exact. 1744 unsigned lcmProducts = 1; 1745 1746 // Let x be the variable we are eliminating. 1747 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note 1748 // that c_l, c_u >= 1) we have: 1749 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u 1750 // We thus generate a constraint: 1751 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. 1752 // Note if c_l = c_u = 1, all integer points captured by the resulting 1753 // constraint correspond to integer points in the original system (i.e., they 1754 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is 1755 // integer exact. 1756 for (auto ubPos : ubIndices) { 1757 for (auto lbPos : lbIndices) { 1758 SmallVector<int64_t, 4> ineq; 1759 ineq.reserve(newRel.getNumCols()); 1760 int64_t lbCoeff = atIneq(lbPos, pos); 1761 // Note that in the comments above, ubCoeff is the negation of the 1762 // coefficient in the canonical form as the view taken here is that of the 1763 // term being moved to the other size of '>='. 1764 int64_t ubCoeff = -atIneq(ubPos, pos); 1765 // TODO: refactor this loop to avoid all branches inside. 1766 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1767 if (l == pos) 1768 continue; 1769 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); 1770 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); 1771 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + 1772 atIneq(lbPos, l) * (lcm / lbCoeff)); 1773 lcmProducts *= lcm; 1774 } 1775 if (darkShadow) { 1776 // The dark shadow is a convex subset of the exact integer shadow. If 1777 // there is a point here, it proves the existence of a solution. 1778 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; 1779 } 1780 // TODO: we need to have a way to add inequalities in-place in 1781 // IntegerRelation instead of creating and copying over. 1782 newRel.addInequality(ineq); 1783 } 1784 } 1785 1786 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1) 1787 << "\n"); 1788 if (lcmProducts == 1 && isResultIntegerExact) 1789 *isResultIntegerExact = true; 1790 1791 // Copy over the constraints not involving this variable. 1792 for (auto nbPos : nbIndices) { 1793 SmallVector<int64_t, 4> ineq; 1794 ineq.reserve(getNumCols() - 1); 1795 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1796 if (l == pos) 1797 continue; 1798 ineq.push_back(atIneq(nbPos, l)); 1799 } 1800 newRel.addInequality(ineq); 1801 } 1802 1803 assert(newRel.getNumConstraints() == 1804 lbIndices.size() * ubIndices.size() + nbIndices.size()); 1805 1806 // Copy over the equalities. 1807 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1808 SmallVector<int64_t, 4> eq; 1809 eq.reserve(newRel.getNumCols()); 1810 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1811 if (l == pos) 1812 continue; 1813 eq.push_back(atEq(r, l)); 1814 } 1815 newRel.addEquality(eq); 1816 } 1817 1818 // GCD tightening and normalization allows detection of more trivially 1819 // redundant constraints. 1820 newRel.gcdTightenInequalities(); 1821 newRel.normalizeConstraintsByGCD(); 1822 newRel.removeTrivialRedundancy(); 1823 clearAndCopyFrom(newRel); 1824 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1825 LLVM_DEBUG(dump()); 1826 } 1827 1828 #undef DEBUG_TYPE 1829 #define DEBUG_TYPE "presburger" 1830 1831 void IntegerRelation::projectOut(unsigned pos, unsigned num) { 1832 if (num == 0) 1833 return; 1834 1835 // 'pos' can be at most getNumCols() - 2 if num > 0. 1836 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position"); 1837 assert(pos + num < getNumCols() && "invalid range"); 1838 1839 // Eliminate as many identifiers as possible using Gaussian elimination. 1840 unsigned currentPos = pos; 1841 unsigned numToEliminate = num; 1842 unsigned numGaussianEliminated = 0; 1843 1844 while (currentPos < getNumIds()) { 1845 unsigned curNumEliminated = 1846 gaussianEliminateIds(currentPos, currentPos + numToEliminate); 1847 ++currentPos; 1848 numToEliminate -= curNumEliminated + 1; 1849 numGaussianEliminated += curNumEliminated; 1850 } 1851 1852 // Eliminate the remaining using Fourier-Motzkin. 1853 for (unsigned i = 0; i < num - numGaussianEliminated; i++) { 1854 unsigned numToEliminate = num - numGaussianEliminated - i; 1855 fourierMotzkinEliminate( 1856 getBestIdToEliminate(*this, pos, pos + numToEliminate)); 1857 } 1858 1859 // Fast/trivial simplifications. 1860 gcdTightenInequalities(); 1861 // Normalize constraints after tightening since the latter impacts this, but 1862 // not the other way round. 1863 normalizeConstraintsByGCD(); 1864 } 1865 1866 namespace { 1867 1868 enum BoundCmpResult { Greater, Less, Equal, Unknown }; 1869 1870 /// Compares two affine bounds whose coefficients are provided in 'first' and 1871 /// 'second'. The last coefficient is the constant term. 1872 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 1873 assert(a.size() == b.size()); 1874 1875 // For the bounds to be comparable, their corresponding identifier 1876 // coefficients should be equal; the constant terms are then compared to 1877 // determine less/greater/equal. 1878 1879 if (!std::equal(a.begin(), a.end() - 1, b.begin())) 1880 return Unknown; 1881 1882 if (a.back() == b.back()) 1883 return Equal; 1884 1885 return a.back() < b.back() ? Less : Greater; 1886 } 1887 } // namespace 1888 1889 // Returns constraints that are common to both A & B. 1890 static void getCommonConstraints(const IntegerRelation &a, 1891 const IntegerRelation &b, IntegerRelation &c) { 1892 c = IntegerRelation(a.getSpace()); 1893 // a naive O(n^2) check should be enough here given the input sizes. 1894 for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) { 1895 for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) { 1896 if (a.getInequality(r) == b.getInequality(s)) { 1897 c.addInequality(a.getInequality(r)); 1898 break; 1899 } 1900 } 1901 } 1902 for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) { 1903 for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) { 1904 if (a.getEquality(r) == b.getEquality(s)) { 1905 c.addEquality(a.getEquality(r)); 1906 break; 1907 } 1908 } 1909 } 1910 } 1911 1912 // Computes the bounding box with respect to 'other' by finding the min of the 1913 // lower bounds and the max of the upper bounds along each of the dimensions. 1914 LogicalResult 1915 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { 1916 assert(isSpaceEqual(otherCst) && "Spaces should match."); 1917 assert(getNumLocalIds() == 0 && "local ids not supported yet here"); 1918 1919 // Get the constraints common to both systems; these will be added as is to 1920 // the union. 1921 IntegerRelation commonCst(PresburgerSpace::getRelationSpace()); 1922 getCommonConstraints(*this, otherCst, commonCst); 1923 1924 std::vector<SmallVector<int64_t, 8>> boundingLbs; 1925 std::vector<SmallVector<int64_t, 8>> boundingUbs; 1926 boundingLbs.reserve(2 * getNumDimIds()); 1927 boundingUbs.reserve(2 * getNumDimIds()); 1928 1929 // To hold lower and upper bounds for each dimension. 1930 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; 1931 // To compute min of lower bounds and max of upper bounds for each dimension. 1932 SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1); 1933 SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1); 1934 // To compute final new lower and upper bounds for the union. 1935 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); 1936 1937 int64_t lbFloorDivisor, otherLbFloorDivisor; 1938 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 1939 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); 1940 if (!extent.hasValue()) 1941 // TODO: symbolic extents when necessary. 1942 // TODO: handle union if a dimension is unbounded. 1943 return failure(); 1944 1945 auto otherExtent = otherCst.getConstantBoundOnDimSize( 1946 d, &otherLb, &otherLbFloorDivisor, &otherUb); 1947 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor) 1948 // TODO: symbolic extents when necessary. 1949 return failure(); 1950 1951 assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); 1952 1953 auto res = compareBounds(lb, otherLb); 1954 // Identify min. 1955 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { 1956 minLb = lb; 1957 // Since the divisor is for a floordiv, we need to convert to ceildiv, 1958 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=> 1959 // div * i >= expr - div + 1. 1960 minLb.back() -= lbFloorDivisor - 1; 1961 } else if (res == BoundCmpResult::Greater) { 1962 minLb = otherLb; 1963 minLb.back() -= otherLbFloorDivisor - 1; 1964 } else { 1965 // Uncomparable - check for constant lower/upper bounds. 1966 auto constLb = getConstantBound(BoundType::LB, d); 1967 auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d); 1968 if (!constLb.hasValue() || !constOtherLb.hasValue()) 1969 return failure(); 1970 std::fill(minLb.begin(), minLb.end(), 0); 1971 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); 1972 } 1973 1974 // Do the same for ub's but max of upper bounds. Identify max. 1975 auto uRes = compareBounds(ub, otherUb); 1976 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { 1977 maxUb = ub; 1978 } else if (uRes == BoundCmpResult::Less) { 1979 maxUb = otherUb; 1980 } else { 1981 // Uncomparable - check for constant lower/upper bounds. 1982 auto constUb = getConstantBound(BoundType::UB, d); 1983 auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d); 1984 if (!constUb.hasValue() || !constOtherUb.hasValue()) 1985 return failure(); 1986 std::fill(maxUb.begin(), maxUb.end(), 0); 1987 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); 1988 } 1989 1990 std::fill(newLb.begin(), newLb.end(), 0); 1991 std::fill(newUb.begin(), newUb.end(), 0); 1992 1993 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, 1994 // and so it's the divisor for newLb and newUb as well. 1995 newLb[d] = lbFloorDivisor; 1996 newUb[d] = -lbFloorDivisor; 1997 // Copy over the symbolic part + constant term. 1998 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); 1999 std::transform(newLb.begin() + getNumDimIds(), newLb.end(), 2000 newLb.begin() + getNumDimIds(), std::negate<int64_t>()); 2001 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); 2002 2003 boundingLbs.push_back(newLb); 2004 boundingUbs.push_back(newUb); 2005 } 2006 2007 // Clear all constraints and add the lower/upper bounds for the bounding box. 2008 clearConstraints(); 2009 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 2010 addInequality(boundingLbs[d]); 2011 addInequality(boundingUbs[d]); 2012 } 2013 2014 // Add the constraints that were common to both systems. 2015 append(commonCst); 2016 removeTrivialRedundancy(); 2017 2018 // TODO: copy over pure symbolic constraints from this and 'other' over to the 2019 // union (since the above are just the union along dimensions); we shouldn't 2020 // be discarding any other constraints on the symbols. 2021 2022 return success(); 2023 } 2024 2025 bool IntegerRelation::isColZero(unsigned pos) const { 2026 unsigned rowPos; 2027 return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) && 2028 !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos); 2029 } 2030 2031 /// Find positions of inequalities and equalities that do not have a coefficient 2032 /// for [pos, pos + num) identifiers. 2033 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos, 2034 unsigned num, 2035 SmallVectorImpl<unsigned> &nbIneqIndices, 2036 SmallVectorImpl<unsigned> &nbEqIndices) { 2037 assert(pos < cst.getNumIds() && "invalid start position"); 2038 assert(pos + num <= cst.getNumIds() && "invalid limit"); 2039 2040 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 2041 // The bounds are to be independent of [offset, offset + num) columns. 2042 unsigned c; 2043 for (c = pos; c < pos + num; ++c) { 2044 if (cst.atIneq(r, c) != 0) 2045 break; 2046 } 2047 if (c == pos + num) 2048 nbIneqIndices.push_back(r); 2049 } 2050 2051 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 2052 // The bounds are to be independent of [offset, offset + num) columns. 2053 unsigned c; 2054 for (c = pos; c < pos + num; ++c) { 2055 if (cst.atEq(r, c) != 0) 2056 break; 2057 } 2058 if (c == pos + num) 2059 nbEqIndices.push_back(r); 2060 } 2061 } 2062 2063 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) { 2064 assert(pos + num <= getNumIds() && "invalid range"); 2065 2066 // Remove constraints that are independent of these identifiers. 2067 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices; 2068 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices); 2069 2070 // Iterate in reverse so that indices don't have to be updated. 2071 // TODO: This method can be made more efficient (because removal of each 2072 // inequality leads to much shifting/copying in the underlying buffer). 2073 for (auto nbIndex : llvm::reverse(nbIneqIndices)) 2074 removeInequality(nbIndex); 2075 for (auto nbIndex : llvm::reverse(nbEqIndices)) 2076 removeEquality(nbIndex); 2077 } 2078 2079 void IntegerRelation::printSpace(raw_ostream &os) const { 2080 PresburgerSpace::print(os); 2081 os << getNumConstraints() << " constraints\n"; 2082 } 2083 2084 void IntegerRelation::print(raw_ostream &os) const { 2085 assert(hasConsistentState()); 2086 printSpace(os); 2087 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 2088 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2089 os << atEq(i, j) << " "; 2090 } 2091 os << "= 0\n"; 2092 } 2093 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 2094 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2095 os << atIneq(i, j) << " "; 2096 } 2097 os << ">= 0\n"; 2098 } 2099 os << '\n'; 2100 } 2101 2102 void IntegerRelation::dump() const { print(llvm::errs()); } 2103 2104 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) { 2105 assert((kind != IdKind::Domain || num == 0) && 2106 "Domain has to be zero in a set"); 2107 return IntegerRelation::insertId(kind, pos, num); 2108 } 2109