1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A class to represent an relation over integer tuples. A relation is 10 // represented as a constraint system over a space of tuples of integer valued 11 // varaiables supporting symbolic identifiers and existential quantification. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/Presburger/IntegerRelation.h" 16 #include "mlir/Analysis/Presburger/LinearTransform.h" 17 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 18 #include "mlir/Analysis/Presburger/Simplex.h" 19 #include "mlir/Analysis/Presburger/Utils.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/DenseSet.h" 22 #include "llvm/Support/Debug.h" 23 24 #define DEBUG_TYPE "presburger" 25 26 using namespace mlir; 27 using namespace presburger; 28 29 using llvm::SmallDenseMap; 30 using llvm::SmallDenseSet; 31 32 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const { 33 return std::make_unique<IntegerRelation>(*this); 34 } 35 36 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const { 37 return std::make_unique<IntegerPolyhedron>(*this); 38 } 39 40 void IntegerRelation::append(const IntegerRelation &other) { 41 assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); 42 43 inequalities.reserveRows(inequalities.getNumRows() + 44 other.getNumInequalities()); 45 equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities()); 46 47 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { 48 addInequality(other.getInequality(r)); 49 } 50 for (unsigned r = 0, e = other.getNumEqualities(); r 
< e; r++) { 51 addEquality(other.getEquality(r)); 52 } 53 } 54 55 bool IntegerRelation::isEqual(const IntegerRelation &other) const { 56 assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); 57 return PresburgerRelation(*this).isEqual(PresburgerRelation(other)); 58 } 59 60 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const { 61 assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); 62 return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other)); 63 } 64 65 MaybeOptimum<SmallVector<Fraction, 8>> 66 IntegerRelation::findRationalLexMin() const { 67 assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); 68 MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin = 69 LexSimplex(*this).findRationalLexMin(); 70 71 if (!maybeLexMin.isBounded()) 72 return maybeLexMin; 73 74 // The Simplex returns the lexmin over all the variables including locals. But 75 // locals are not actually part of the space and should not be returned in the 76 // result. Since the locals are placed last in the list of identifiers, they 77 // will be minimized last in the lexmin. So simply truncating out the locals 78 // from the end of the answer gives the desired lexmin over the dimensions. 79 assert(maybeLexMin->size() == getNumIds() && 80 "Incorrect number of vars in lexMin!"); 81 maybeLexMin->resize(getNumDimAndSymbolIds()); 82 return maybeLexMin; 83 } 84 85 MaybeOptimum<SmallVector<int64_t, 8>> 86 IntegerRelation::findIntegerLexMin() const { 87 assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); 88 MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin = 89 LexSimplex(*this).findIntegerLexMin(); 90 91 if (!maybeLexMin.isBounded()) 92 return maybeLexMin.getKind(); 93 94 // The Simplex returns the lexmin over all the variables including locals. But 95 // locals are not actually part of the space and should not be returned in the 96 // result. 
Since the locals are placed last in the list of identifiers, they 97 // will be minimized last in the lexmin. So simply truncating out the locals 98 // from the end of the answer gives the desired lexmin over the dimensions. 99 assert(maybeLexMin->size() == getNumIds() && 100 "Incorrect number of vars in lexMin!"); 101 maybeLexMin->resize(getNumDimAndSymbolIds()); 102 return maybeLexMin; 103 } 104 105 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { 106 assert(pos <= getNumIdKind(kind)); 107 108 unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num); 109 inequalities.insertColumns(insertPos, num); 110 equalities.insertColumns(insertPos, num); 111 return insertPos; 112 } 113 114 unsigned IntegerRelation::appendId(IdKind kind, unsigned num) { 115 unsigned pos = getNumIdKind(kind); 116 return insertId(kind, pos, num); 117 } 118 119 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) { 120 assert(eq.size() == getNumCols()); 121 unsigned row = equalities.appendExtraRow(); 122 for (unsigned i = 0, e = eq.size(); i < e; ++i) 123 equalities(row, i) = eq[i]; 124 } 125 126 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) { 127 assert(inEq.size() == getNumCols()); 128 unsigned row = inequalities.appendExtraRow(); 129 for (unsigned i = 0, e = inEq.size(); i < e; ++i) 130 inequalities(row, i) = inEq[i]; 131 } 132 133 void IntegerRelation::removeId(IdKind kind, unsigned pos) { 134 removeIdRange(kind, pos, pos + 1); 135 } 136 137 void IntegerRelation::removeId(unsigned pos) { removeIdRange(pos, pos + 1); } 138 139 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart, 140 unsigned idLimit) { 141 assert(idLimit <= getNumIdKind(kind)); 142 143 if (idStart >= idLimit) 144 return; 145 146 // Remove eliminated identifiers from the constraints. 
147 unsigned offset = getIdKindOffset(kind); 148 equalities.removeColumns(offset + idStart, idLimit - idStart); 149 inequalities.removeColumns(offset + idStart, idLimit - idStart); 150 151 // Remove eliminated identifiers from the space. 152 PresburgerLocalSpace::removeIdRange(kind, idStart, idLimit); 153 } 154 155 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) { 156 assert(idLimit <= getNumIds()); 157 158 if (idStart >= idLimit) 159 return; 160 161 // Helper function to remove ids of the specified kind in the given range 162 // [start, limit), The range is absolute (i.e. it is not relative to the kind 163 // of identifier). Also updates `limit` to reflect the deleted identifiers. 164 auto removeIdKindInRange = [this](IdKind kind, unsigned &start, 165 unsigned &limit) { 166 if (start >= limit) 167 return; 168 169 unsigned offset = getIdKindOffset(kind); 170 unsigned num = getNumIdKind(kind); 171 172 // Get `start`, `limit` relative to the specified kind. 173 unsigned relativeStart = 174 start <= offset ? 0 : std::min(num, start - offset); 175 unsigned relativeLimit = 176 limit <= offset ? 0 : std::min(num, limit - offset); 177 178 // Remove ids of the specified kind in the relative range. 179 removeIdRange(kind, relativeStart, relativeLimit); 180 181 // Update `limit` to reflect deleted identifiers. 182 // `start` does not need to be updated because any identifiers that are 183 // deleted are after position `start`. 
184 limit -= relativeLimit - relativeStart; 185 }; 186 187 removeIdKindInRange(IdKind::Domain, idStart, idLimit); 188 removeIdKindInRange(IdKind::Range, idStart, idLimit); 189 removeIdKindInRange(IdKind::Symbol, idStart, idLimit); 190 removeIdKindInRange(IdKind::Local, idStart, idLimit); 191 } 192 193 void IntegerRelation::removeEquality(unsigned pos) { 194 equalities.removeRow(pos); 195 } 196 197 void IntegerRelation::removeInequality(unsigned pos) { 198 inequalities.removeRow(pos); 199 } 200 201 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) { 202 if (start >= end) 203 return; 204 equalities.removeRows(start, end - start); 205 } 206 207 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) { 208 if (start >= end) 209 return; 210 inequalities.removeRows(start, end - start); 211 } 212 213 void IntegerRelation::swapId(unsigned posA, unsigned posB) { 214 assert(posA < getNumIds() && "invalid position A"); 215 assert(posB < getNumIds() && "invalid position B"); 216 217 if (posA == posB) 218 return; 219 220 inequalities.swapColumns(posA, posB); 221 equalities.swapColumns(posA, posB); 222 } 223 224 void IntegerRelation::clearConstraints() { 225 equalities.resizeVertically(0); 226 inequalities.resizeVertically(0); 227 } 228 229 /// Gather all lower and upper bounds of the identifier at `pos`, and 230 /// optionally any equalities on it. In addition, the bounds are to be 231 /// independent of identifiers in position range [`offset`, `offset` + `num`). 
232 void IntegerRelation::getLowerAndUpperBoundIndices( 233 unsigned pos, SmallVectorImpl<unsigned> *lbIndices, 234 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices, 235 unsigned offset, unsigned num) const { 236 assert(pos < getNumIds() && "invalid position"); 237 assert(offset + num < getNumCols() && "invalid range"); 238 239 // Checks for a constraint that has a non-zero coeff for the identifiers in 240 // the position range [offset, offset + num) while ignoring `pos`. 241 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) { 242 unsigned c, f; 243 auto cst = isEq ? getEquality(r) : getInequality(r); 244 for (c = offset, f = offset + num; c < f; ++c) { 245 if (c == pos) 246 continue; 247 if (cst[c] != 0) 248 break; 249 } 250 return c < f; 251 }; 252 253 // Gather all lower bounds and upper bounds of the variable. Since the 254 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 255 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 256 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 257 // The bounds are to be independent of [offset, offset + num) columns. 258 if (containsConstraintDependentOnRange(r, /*isEq=*/false)) 259 continue; 260 if (atIneq(r, pos) >= 1) { 261 // Lower bound. 262 lbIndices->push_back(r); 263 } else if (atIneq(r, pos) <= -1) { 264 // Upper bound. 265 ubIndices->push_back(r); 266 } 267 } 268 269 // An equality is both a lower and upper bound. Record any equalities 270 // involving the pos^th identifier. 
271 if (!eqIndices) 272 return; 273 274 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 275 if (atEq(r, pos) == 0) 276 continue; 277 if (containsConstraintDependentOnRange(r, /*isEq=*/true)) 278 continue; 279 eqIndices->push_back(r); 280 } 281 } 282 283 bool IntegerRelation::hasConsistentState() const { 284 if (!inequalities.hasConsistentState()) 285 return false; 286 if (!equalities.hasConsistentState()) 287 return false; 288 return true; 289 } 290 291 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) { 292 if (values.empty()) 293 return; 294 assert(pos + values.size() <= getNumIds() && 295 "invalid position or too many values"); 296 // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the 297 // constant term and removing the id x_j. We do this for all the ids 298 // pos, pos + 1, ... pos + values.size() - 1. 299 unsigned constantColPos = getNumCols() - 1; 300 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 301 inequalities.addToColumn(i + pos, constantColPos, values[i]); 302 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 303 equalities.addToColumn(i + pos, constantColPos, values[i]); 304 removeIdRange(pos, pos + values.size()); 305 } 306 307 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) { 308 *this = other; 309 } 310 311 // Searches for a constraint with a non-zero coefficient at `colIdx` in 312 // equality (isEq=true) or inequality (isEq=false) constraints. 313 // Returns true and sets row found in search in `rowIdx`, false otherwise. 314 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, 315 unsigned *rowIdx) const { 316 assert(colIdx < getNumCols() && "position out of bounds"); 317 auto at = [&](unsigned rowIdx) -> int64_t { 318 return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); 319 }; 320 unsigned e = isEq ? 
getNumEqualities() : getNumInequalities(); 321 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { 322 if (at(*rowIdx) != 0) { 323 return true; 324 } 325 } 326 return false; 327 } 328 329 void IntegerRelation::normalizeConstraintsByGCD() { 330 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 331 equalities.normalizeRow(i); 332 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 333 inequalities.normalizeRow(i); 334 } 335 336 bool IntegerRelation::hasInvalidConstraint() const { 337 assert(hasConsistentState()); 338 auto check = [&](bool isEq) -> bool { 339 unsigned numCols = getNumCols(); 340 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); 341 for (unsigned i = 0, e = numRows; i < e; ++i) { 342 unsigned j; 343 for (j = 0; j < numCols - 1; ++j) { 344 int64_t v = isEq ? atEq(i, j) : atIneq(i, j); 345 // Skip rows with non-zero variable coefficients. 346 if (v != 0) 347 break; 348 } 349 if (j < numCols - 1) { 350 continue; 351 } 352 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. 353 // Example invalid constraints include: '1 == 0' or '-1 >= 0' 354 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); 355 if ((isEq && v != 0) || (!isEq && v < 0)) { 356 return true; 357 } 358 } 359 return false; 360 }; 361 if (check(/*isEq=*/true)) 362 return true; 363 return check(/*isEq=*/false); 364 } 365 366 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at 367 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be 368 /// updated as they have already been eliminated. 369 static void eliminateFromConstraint(IntegerRelation *constraints, 370 unsigned rowIdx, unsigned pivotRow, 371 unsigned pivotCol, unsigned elimColStart, 372 bool isEq) { 373 // Skip if equality 'rowIdx' if same as 'pivotRow'. 374 if (isEq && rowIdx == pivotRow) 375 return; 376 auto at = [&](unsigned i, unsigned j) -> int64_t { 377 return isEq ? 
constraints->atEq(i, j) : constraints->atIneq(i, j); 378 }; 379 int64_t leadCoeff = at(rowIdx, pivotCol); 380 // Skip if leading coefficient at 'rowIdx' is already zero. 381 if (leadCoeff == 0) 382 return; 383 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); 384 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; 385 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); 386 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); 387 int64_t rowMultiplier = lcm / std::abs(leadCoeff); 388 389 unsigned numCols = constraints->getNumCols(); 390 for (unsigned j = 0; j < numCols; ++j) { 391 // Skip updating column 'j' if it was just eliminated. 392 if (j >= elimColStart && j < pivotCol) 393 continue; 394 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + 395 rowMultiplier * at(rowIdx, j); 396 isEq ? constraints->atEq(rowIdx, j) = v 397 : constraints->atIneq(rowIdx, j) = v; 398 } 399 } 400 401 /// Returns the position of the identifier that has the minimum <number of lower 402 /// bounds> times <number of upper bounds> from the specified range of 403 /// identifiers [start, end). It is often best to eliminate in the increasing 404 /// order of these counts when doing Fourier-Motzkin elimination since FM adds 405 /// that many new constraints. 
406 static unsigned getBestIdToEliminate(const IntegerRelation &cst, unsigned start, 407 unsigned end) { 408 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); 409 410 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { 411 unsigned numLb = 0; 412 unsigned numUb = 0; 413 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 414 if (cst.atIneq(r, pos) > 0) { 415 ++numLb; 416 } else if (cst.atIneq(r, pos) < 0) { 417 ++numUb; 418 } 419 } 420 return numLb * numUb; 421 }; 422 423 unsigned minLoc = start; 424 unsigned min = getProductOfNumLowerUpperBounds(start); 425 for (unsigned c = start + 1; c < end; c++) { 426 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); 427 if (numLbUbProduct < min) { 428 min = numLbUbProduct; 429 minLoc = c; 430 } 431 } 432 return minLoc; 433 } 434 435 // Checks for emptiness of the set by eliminating identifiers successively and 436 // using the GCD test (on all equality constraints) and checking for trivially 437 // invalid constraints. Returns 'true' if the constraint system is found to be 438 // empty; false otherwise. 439 bool IntegerRelation::isEmpty() const { 440 if (isEmptyByGCDTest() || hasInvalidConstraint()) 441 return true; 442 443 IntegerRelation tmpCst(*this); 444 445 // First, eliminate as many local variables as possible using equalities. 446 tmpCst.removeRedundantLocalVars(); 447 if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint()) 448 return true; 449 450 // Eliminate as many identifiers as possible using Gaussian elimination. 451 unsigned currentPos = 0; 452 while (currentPos < tmpCst.getNumIds()) { 453 tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); 454 ++currentPos; 455 // We check emptiness through trivial checks after eliminating each ID to 456 // detect emptiness early. 
Since the checks isEmptyByGCDTest() and 457 // hasInvalidConstraint() are linear time and single sweep on the constraint 458 // buffer, this appears reasonable - but can optimize in the future. 459 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) 460 return true; 461 } 462 463 // Eliminate the remaining using FM. 464 for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { 465 tmpCst.fourierMotzkinEliminate( 466 getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); 467 // Check for a constraint explosion. This rarely happens in practice, but 468 // this check exists as a safeguard against improperly constructed 469 // constraint systems or artificially created arbitrarily complex systems 470 // that aren't the intended use case for IntegerRelation. This is 471 // needed since FM has a worst case exponential complexity in theory. 472 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { 473 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n"); 474 return false; 475 } 476 477 // FM wouldn't have modified the equalities in any way. So no need to again 478 // run GCD test. Check for trivial invalid constraints. 479 if (tmpCst.hasInvalidConstraint()) 480 return true; 481 } 482 return false; 483 } 484 485 // Runs the GCD test on all equality constraints. Returns 'true' if this test 486 // fails on any equality. Returns 'false' otherwise. 487 // This test can be used to disprove the existence of a solution. If it returns 488 // true, no integer solution to the equality constraints can exist. 489 // 490 // GCD test definition: 491 // 492 // The equality constraint: 493 // 494 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 495 // 496 // has an integer solution iff: 497 // 498 // GCD of c_1, c_2, ..., c_n divides c_0. 
499 // 500 bool IntegerRelation::isEmptyByGCDTest() const { 501 assert(hasConsistentState()); 502 unsigned numCols = getNumCols(); 503 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 504 uint64_t gcd = std::abs(atEq(i, 0)); 505 for (unsigned j = 1; j < numCols - 1; ++j) { 506 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); 507 } 508 int64_t v = std::abs(atEq(i, numCols - 1)); 509 if (gcd > 0 && (v % gcd != 0)) { 510 return true; 511 } 512 } 513 return false; 514 } 515 516 // Returns a matrix where each row is a vector along which the polytope is 517 // bounded. The span of the returned vectors is guaranteed to contain all 518 // such vectors. The returned vectors are NOT guaranteed to be linearly 519 // independent. This function should not be called on empty sets. 520 // 521 // It is sufficient to check the perpendiculars of the constraints, as the set 522 // of perpendiculars which are bounded must span all bounded directions. 523 Matrix IntegerRelation::getBoundedDirections() const { 524 // Note that it is necessary to add the equalities too (which the constructor 525 // does) even though we don't need to check if they are bounded; whether an 526 // inequality is bounded or not depends on what other constraints, including 527 // equalities, are present. 528 Simplex simplex(*this); 529 530 assert(!simplex.isEmpty() && "It is not meaningful to ask whether a " 531 "direction is bounded in an empty set."); 532 533 SmallVector<unsigned, 8> boundedIneqs; 534 // The constructor adds the inequalities to the simplex first, so this 535 // processes all the inequalities. 536 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 537 if (simplex.isBoundedAlongConstraint(i)) 538 boundedIneqs.push_back(i); 539 } 540 541 // The direction vector is given by the coefficients and does not include the 542 // constant term, so the matrix has one fewer column. 
543 unsigned dirsNumCols = getNumCols() - 1; 544 Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols); 545 546 // Copy the bounded inequalities. 547 unsigned row = 0; 548 for (unsigned i : boundedIneqs) { 549 for (unsigned col = 0; col < dirsNumCols; ++col) 550 dirs(row, col) = atIneq(i, col); 551 ++row; 552 } 553 554 // Copy the equalities. All the equalities' perpendiculars are bounded. 555 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 556 for (unsigned col = 0; col < dirsNumCols; ++col) 557 dirs(row, col) = atEq(i, col); 558 ++row; 559 } 560 561 return dirs; 562 } 563 564 bool eqInvolvesSuffixDims(const IntegerRelation &rel, unsigned eqIndex, 565 unsigned numDims) { 566 for (unsigned e = rel.getNumIds(), j = e - numDims; j < e; ++j) 567 if (rel.atEq(eqIndex, j) != 0) 568 return true; 569 return false; 570 } 571 bool ineqInvolvesSuffixDims(const IntegerRelation &rel, unsigned ineqIndex, 572 unsigned numDims) { 573 for (unsigned e = rel.getNumIds(), j = e - numDims; j < e; ++j) 574 if (rel.atIneq(ineqIndex, j) != 0) 575 return true; 576 return false; 577 } 578 579 void removeConstraintsInvolvingSuffixDims(IntegerRelation &rel, 580 unsigned unboundedDims) { 581 // We iterate backwards so that whether we remove constraint i - 1 or not, the 582 // next constraint to be tested is always i - 2. 583 for (unsigned i = rel.getNumEqualities(); i > 0; i--) 584 if (eqInvolvesSuffixDims(rel, i - 1, unboundedDims)) 585 rel.removeEquality(i - 1); 586 for (unsigned i = rel.getNumInequalities(); i > 0; i--) 587 if (ineqInvolvesSuffixDims(rel, i - 1, unboundedDims)) 588 rel.removeInequality(i - 1); 589 } 590 591 bool IntegerRelation::isIntegerEmpty() const { 592 return !findIntegerSample().hasValue(); 593 } 594 595 /// Let this set be S. If S is bounded then we directly call into the GBR 596 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e., 597 /// vectors v such that S extends to infinity along v or -v. 
In this case we 598 /// use an algorithm described in the integer set library (isl) manual and used 599 /// by the isl_set_sample function in that library. The algorithm is: 600 /// 601 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all 602 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the 603 /// dimensions. 604 /// 605 /// 2) Construct a set B by removing all constraints that involve 606 /// the unbounded dimensions and then deleting the unbounded dimensions. Note 607 /// that B is a Bounded set. 608 /// 609 /// 3) Try to obtain a sample from B using the GBR sampling 610 /// algorithm. If no sample is found, return that S is empty. 611 /// 612 /// 4) Otherwise, substitute the obtained sample into S*T to obtain a set 613 /// C. C is a full-dimensional Cone and always contains a sample. 614 /// 615 /// 5) Obtain an integer sample from C. 616 /// 617 /// 6) Return T*v, where v is the concatenation of the samples from B and C. 618 /// 619 /// The following is a sketch of a proof that 620 /// a) If the algorithm returns empty, then S is empty. 621 /// b) If the algorithm returns a sample, it is a valid sample in S. 622 /// 623 /// The algorithm returns empty only if B is empty, in which case S*T is 624 /// certainly empty since B was obtained by removing constraints and then 625 /// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector 626 /// v is in S*T iff T*v is in S. So in this case, since 627 /// S*T is empty, S is empty too. 628 /// 629 /// Otherwise, the algorithm substitutes the sample from B into S*T. All the 630 /// constraints of S*T that did not involve unbounded dimensions are satisfied 631 /// by this substitution. All dimensions in the linear span of the dimensions 632 /// outside the prefix are unbounded in S*T (step 1). 
Substituting values for 633 /// the bounded dimensions cannot make these dimensions bounded, and these are 634 /// the only remaining dimensions in C, so C is unbounded along every vector (in 635 /// the positive or negative direction, or both). C is hence a full-dimensional 636 /// cone and therefore always contains an integer point. 637 /// 638 /// Concatenating the samples from B and C gives a sample v in S*T, so the 639 /// returned sample T*v is a sample in S. 640 Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const { 641 // First, try the GCD test heuristic. 642 if (isEmptyByGCDTest()) 643 return {}; 644 645 Simplex simplex(*this); 646 if (simplex.isEmpty()) 647 return {}; 648 649 // For a bounded set, we directly call into the GBR sampling algorithm. 650 if (!simplex.isUnbounded()) 651 return simplex.findIntegerSample(); 652 653 // The set is unbounded. We cannot directly use the GBR algorithm. 654 // 655 // m is a matrix containing, in each row, a vector in which S is 656 // bounded, such that the linear span of all these dimensions contains all 657 // bounded dimensions in S. 658 Matrix m = getBoundedDirections(); 659 // In column echelon form, each row of m occupies only the first rank(m) 660 // columns and has zeros on the other columns. The transform T that brings S 661 // to column echelon form is unimodular as well, so this is a suitable 662 // transform to use in step 1 of the algorithm. 663 std::pair<unsigned, LinearTransform> result = 664 LinearTransform::makeTransformToColumnEchelon(std::move(m)); 665 const LinearTransform &transform = result.second; 666 // 1) Apply T to S to obtain S*T. 667 IntegerRelation transformedSet = transform.applyTo(*this); 668 669 // 2) Remove the unbounded dimensions and constraints involving them to 670 // obtain a bounded set. 
671 IntegerRelation boundedSet(transformedSet); 672 unsigned numBoundedDims = result.first; 673 unsigned numUnboundedDims = getNumIds() - numBoundedDims; 674 removeConstraintsInvolvingSuffixDims(boundedSet, numUnboundedDims); 675 boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds()); 676 677 // 3) Try to obtain a sample from the bounded set. 678 Optional<SmallVector<int64_t, 8>> boundedSample = 679 Simplex(boundedSet).findIntegerSample(); 680 if (!boundedSample) 681 return {}; 682 assert(boundedSet.containsPoint(*boundedSample) && 683 "Simplex returned an invalid sample!"); 684 685 // 4) Substitute the values of the bounded dimensions into S*T to obtain a 686 // full-dimensional cone, which necessarily contains an integer sample. 687 transformedSet.setAndEliminate(0, *boundedSample); 688 IntegerRelation &cone = transformedSet; 689 690 // 5) Obtain an integer sample from the cone. 691 // 692 // We shrink the cone such that for any rational point in the shrunken cone, 693 // rounding up each of the point's coordinates produces a point that still 694 // lies in the original cone. 695 // 696 // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i. 697 // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the 698 // shrunken cone will have the inequality tightened by some amount s, such 699 // that if x satisfies the shrunken cone's tightened inequality, then x + e 700 // satisfies the original inequality, i.e., 701 // 702 // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0 703 // 704 // for any e_i values in [0, 1). In fact, we will handle the slightly more 705 // general case where e_i can be in [0, 1]. For example, consider the 706 // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low 707 // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS 708 // is minimized when we add 1 to the x_i with negative coefficient a_i and 709 // keep the other x_i the same. 
In the example, we would get x = (3, 1, 1), 710 // changing the value of the LHS by -3 + -7 = -10. 711 // 712 // In general, the value of the LHS can change by at most the sum of the 713 // negative a_i, so we accomodate this by shifting the inequality by this 714 // amount for the shrunken cone. 715 for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) { 716 for (unsigned j = 0; j < cone.getNumIds(); ++j) { 717 int64_t coeff = cone.atIneq(i, j); 718 if (coeff < 0) 719 cone.atIneq(i, cone.getNumIds()) += coeff; 720 } 721 } 722 723 // Obtain an integer sample in the cone by rounding up a rational point from 724 // the shrunken cone. Shrinking the cone amounts to shifting its apex 725 // "inwards" without changing its "shape"; the shrunken cone is still a 726 // full-dimensional cone and is hence non-empty. 727 Simplex shrunkenConeSimplex(cone); 728 assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!"); 729 730 // The sample will always exist since the shrunken cone is non-empty. 731 SmallVector<Fraction, 8> shrunkenConeSample = 732 *shrunkenConeSimplex.getRationalSample(); 733 734 SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil)); 735 736 // 6) Return transform * concat(boundedSample, coneSample). 737 SmallVector<int64_t, 8> &sample = boundedSample.getValue(); 738 sample.append(coneSample.begin(), coneSample.end()); 739 return transform.postMultiplyWithColumn(sample); 740 } 741 742 /// Helper to evaluate an affine expression at a point. 743 /// The expression is a list of coefficients for the dimensions followed by the 744 /// constant term. 
745 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) { 746 assert(expr.size() == 1 + point.size() && 747 "Dimensionalities of point and expression don't match!"); 748 int64_t value = expr.back(); 749 for (unsigned i = 0; i < point.size(); ++i) 750 value += expr[i] * point[i]; 751 return value; 752 } 753 754 /// A point satisfies an equality iff the value of the equality at the 755 /// expression is zero, and it satisfies an inequality iff the value of the 756 /// inequality at that point is non-negative. 757 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const { 758 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 759 if (valueAt(getEquality(i), point) != 0) 760 return false; 761 } 762 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 763 if (valueAt(getInequality(i), point) < 0) 764 return false; 765 } 766 return true; 767 } 768 769 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const { 770 std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds()); 771 SmallVector<unsigned, 4> denominators(getNumLocalIds()); 772 getLocalReprs(dividends, denominators, repr); 773 } 774 775 void IntegerRelation::getLocalReprs( 776 std::vector<SmallVector<int64_t, 8>> ÷nds, 777 SmallVector<unsigned, 4> &denominators) const { 778 std::vector<MaybeLocalRepr> repr(getNumLocalIds()); 779 getLocalReprs(dividends, denominators, repr); 780 } 781 782 void IntegerRelation::getLocalReprs( 783 std::vector<SmallVector<int64_t, 8>> ÷nds, 784 SmallVector<unsigned, 4> &denominators, 785 std::vector<MaybeLocalRepr> &repr) const { 786 787 repr.resize(getNumLocalIds()); 788 dividends.resize(getNumLocalIds()); 789 denominators.resize(getNumLocalIds()); 790 791 SmallVector<bool, 8> foundRepr(getNumIds(), false); 792 for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) 793 foundRepr[i] = true; 794 795 unsigned divOffset = getNumDimAndSymbolIds(); 796 bool changed; 797 do { 798 // Each time changed is true, at 
end of this iteration, one or more local 799 // vars have been detected as floor divs. 800 changed = false; 801 for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { 802 if (!foundRepr[i + divOffset]) { 803 MaybeLocalRepr res = computeSingleVarRepr( 804 *this, foundRepr, divOffset + i, dividends[i], denominators[i]); 805 if (!res) 806 continue; 807 foundRepr[i + divOffset] = true; 808 repr[i] = res; 809 changed = true; 810 } 811 } 812 } while (changed); 813 814 // Set 0 denominator for identifiers for which no division representation 815 // could be found. 816 for (unsigned i = 0, e = repr.size(); i < e; ++i) 817 if (!repr[i]) 818 denominators[i] = 0; 819 } 820 821 /// Tightens inequalities given that we are dealing with integer spaces. This is 822 /// analogous to the GCD test but applied to inequalities. The constant term can 823 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., 824 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a 825 /// fast method - linear in the number of coefficients. 826 // Example on how this affects practical cases: consider the scenario: 827 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield 828 // j >= 100 instead of the tighter (exact) j >= 128. 829 void IntegerRelation::gcdTightenInequalities() { 830 unsigned numCols = getNumCols(); 831 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 832 // Normalize the constraint and tighten the constant term by the GCD. 833 uint64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1); 834 if (gcd > 1) 835 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd); 836 } 837 } 838 839 // Eliminates all identifier variables in column range [posStart, posLimit). 840 // Returns the number of variables eliminated. 841 unsigned IntegerRelation::gaussianEliminateIds(unsigned posStart, 842 unsigned posLimit) { 843 // Return if identifier positions to eliminate are out of range. 
844 assert(posLimit <= getNumIds()); 845 assert(hasConsistentState()); 846 847 if (posStart >= posLimit) 848 return 0; 849 850 gcdTightenInequalities(); 851 852 unsigned pivotCol = 0; 853 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { 854 // Find a row which has a non-zero coefficient in column 'j'. 855 unsigned pivotRow; 856 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) { 857 // No pivot row in equalities with non-zero at 'pivotCol'. 858 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) { 859 // If inequalities are also non-zero in 'pivotCol', it can be 860 // eliminated. 861 continue; 862 } 863 break; 864 } 865 866 // Eliminate identifier at 'pivotCol' from each equality row. 867 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 868 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 869 /*isEq=*/true); 870 equalities.normalizeRow(i); 871 } 872 873 // Eliminate identifier at 'pivotCol' from each inequality row. 874 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 875 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 876 /*isEq=*/false); 877 inequalities.normalizeRow(i); 878 } 879 removeEquality(pivotRow); 880 gcdTightenInequalities(); 881 } 882 // Update position limit based on number eliminated. 883 posLimit = pivotCol; 884 // Remove eliminated columns from all constraints. 885 removeIdRange(posStart, posLimit); 886 return posLimit - posStart; 887 } 888 889 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin 890 // to check if a constraint is redundant. 891 void IntegerRelation::removeRedundantInequalities() { 892 SmallVector<bool, 32> redun(getNumInequalities(), false); 893 // To check if an inequality is redundant, we replace the inequality by its 894 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting 895 // system is empty. If it is, the inequality is redundant. 
896 IntegerRelation tmpCst(*this); 897 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 898 // Change the inequality to its complement. 899 tmpCst.inequalities.negateRow(r); 900 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; 901 if (tmpCst.isEmpty()) { 902 redun[r] = true; 903 // Zero fill the redundant inequality. 904 inequalities.fillRow(r, /*value=*/0); 905 tmpCst.inequalities.fillRow(r, /*value=*/0); 906 } else { 907 // Reverse the change (to avoid recreating tmpCst each time). 908 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; 909 tmpCst.inequalities.negateRow(r); 910 } 911 } 912 913 unsigned pos = 0; 914 for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { 915 if (!redun[r]) 916 inequalities.copyRow(r, pos++); 917 } 918 inequalities.resizeVertically(pos); 919 } 920 921 // A more complex check to eliminate redundant inequalities and equalities. Uses 922 // Simplex to check if a constraint is redundant. 923 void IntegerRelation::removeRedundantConstraints() { 924 // First, we run gcdTightenInequalities. This allows us to catch some 925 // constraints which are not redundant when considering rational solutions 926 // but are redundant in terms of integer solutions. 927 gcdTightenInequalities(); 928 Simplex simplex(*this); 929 simplex.detectRedundant(); 930 931 unsigned pos = 0; 932 unsigned numIneqs = getNumInequalities(); 933 // Scan to get rid of all inequalities marked redundant, in-place. In Simplex, 934 // the first constraints added are the inequalities. 935 for (unsigned r = 0; r < numIneqs; r++) { 936 if (!simplex.isMarkedRedundant(r)) 937 inequalities.copyRow(r, pos++); 938 } 939 inequalities.resizeVertically(pos); 940 941 // Scan to get rid of all equalities marked redundant, in-place. In Simplex, 942 // after the inequalities, a pair of constraints for each equality is added. 943 // An equality is redundant if both the inequalities in its pair are 944 // redundant. 
945 pos = 0; 946 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 947 if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) && 948 simplex.isMarkedRedundant(numIneqs + 2 * r + 1))) 949 equalities.copyRow(r, pos++); 950 } 951 equalities.resizeVertically(pos); 952 } 953 954 Optional<uint64_t> IntegerRelation::computeVolume() const { 955 assert(getNumSymbolIds() == 0 && "Symbols are not yet supported!"); 956 957 Simplex simplex(*this); 958 // If the polytope is rationally empty, there are certainly no integer 959 // points. 960 if (simplex.isEmpty()) 961 return 0; 962 963 // Just find the maximum and minimum integer value of each non-local id 964 // separately, thus finding the number of integer values each such id can 965 // take. Multiplying these together gives a valid overapproximation of the 966 // number of integer points in the relation. The result this gives is 967 // equivalent to projecting (rationally) the relation onto its non-local ids 968 // and returning the number of integer points in a minimal axis-parallel 969 // hyperrectangular overapproximation of that. 970 // 971 // We also handle the special case where one dimension is unbounded and 972 // another dimension can take no integer values. In this case, the volume is 973 // zero. 974 // 975 // If there is no such empty dimension, if any dimension is unbounded we 976 // just return the result as unbounded. 977 uint64_t count = 1; 978 SmallVector<int64_t, 8> dim(getNumIds() + 1); 979 bool hasUnboundedId = false; 980 for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) { 981 dim[i] = 1; 982 MaybeOptimum<int64_t> min, max; 983 std::tie(min, max) = simplex.computeIntegerBounds(dim); 984 dim[i] = 0; 985 986 assert((!min.isEmpty() && !max.isEmpty()) && 987 "Polytope should be rationally non-empty!"); 988 989 // One of the dimensions is unbounded. Note this fact. We will return 990 // unbounded if none of the other dimensions makes the volume zero. 
991 if (min.isUnbounded() || max.isUnbounded()) { 992 hasUnboundedId = true; 993 continue; 994 } 995 996 // In this case there are no valid integer points and the volume is 997 // definitely zero. 998 if (min.getBoundedOptimum() > max.getBoundedOptimum()) 999 return 0; 1000 1001 count *= (*max - *min + 1); 1002 } 1003 1004 if (count == 0) 1005 return 0; 1006 if (hasUnboundedId) 1007 return {}; 1008 return count; 1009 } 1010 1011 void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) { 1012 assert(posA < getNumLocalIds() && "Invalid local id position"); 1013 assert(posB < getNumLocalIds() && "Invalid local id position"); 1014 1015 unsigned localOffset = getIdKindOffset(IdKind::Local); 1016 posA += localOffset; 1017 posB += localOffset; 1018 inequalities.addToColumn(posB, posA, 1); 1019 equalities.addToColumn(posB, posA, 1); 1020 removeId(posB); 1021 } 1022 1023 /// Adds additional local ids to the sets such that they both have the union 1024 /// of the local ids in each set, without changing the set of points that 1025 /// lie in `this` and `other`. 1026 /// 1027 /// To detect local ids that always take the same in both sets, each local id is 1028 /// represented as a floordiv with constant denominator in terms of other ids. 1029 /// After extracting these divisions, local ids with the same division 1030 /// representation are considered duplicate and are merged. It is possible that 1031 /// division representation for some local id cannot be obtained, and thus these 1032 /// local ids are not considered for detecting duplicates. 1033 void IntegerRelation::mergeLocalIds(IntegerRelation &other) { 1034 assert(PresburgerSpace::isEqual(other) && "Spaces should match."); 1035 1036 IntegerRelation &relA = *this; 1037 IntegerRelation &relB = other; 1038 1039 // Merge local ids of relA and relB without using division information, 1040 // i.e. 
append local ids of `relB` to `relA` and insert local ids of `relA` 1041 // to `relB` at start of its local ids. 1042 unsigned initLocals = relA.getNumLocalIds(); 1043 insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds()); 1044 relB.insertId(IdKind::Local, 0, initLocals); 1045 1046 // Get division representations from each rel. 1047 std::vector<SmallVector<int64_t, 8>> divsA, divsB; 1048 SmallVector<unsigned, 4> denomsA, denomsB; 1049 relA.getLocalReprs(divsA, denomsA); 1050 relB.getLocalReprs(divsB, denomsB); 1051 1052 // Copy division information for relB into `divsA` and `denomsA`, so that 1053 // these have the combined division information of both rels. Since newly 1054 // added local variables in relA and relB have no constraints, they will not 1055 // have any division representation. 1056 std::copy(divsB.begin() + initLocals, divsB.end(), 1057 divsA.begin() + initLocals); 1058 std::copy(denomsB.begin() + initLocals, denomsB.end(), 1059 denomsA.begin() + initLocals); 1060 1061 // Merge function that merges the local variables in both sets by treating 1062 // them as the same identifier. 1063 auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool { 1064 relA.eliminateRedundantLocalId(i, j); 1065 relB.eliminateRedundantLocalId(i, j); 1066 return true; 1067 }; 1068 1069 // Merge all divisions by removing duplicate divisions. 1070 unsigned localOffset = getIdKindOffset(IdKind::Local); 1071 removeDuplicateDivs(divsA, denomsA, localOffset, merge); 1072 } 1073 1074 /// Removes local variables using equalities. Each equality is checked if it 1075 /// can be reduced to the form: `e = affine-expr`, where `e` is a local 1076 /// variable and `affine-expr` is an affine expression not containing `e`. 1077 /// If an equality satisfies this form, the local variable is replaced in 1078 /// each constraint and then removed. The equality used to replace this local 1079 /// variable is also removed. 
1080 void IntegerRelation::removeRedundantLocalVars() { 1081 // Normalize the equality constraints to reduce coefficients of local 1082 // variables to 1 wherever possible. 1083 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 1084 equalities.normalizeRow(i); 1085 1086 while (true) { 1087 unsigned i, e, j, f; 1088 for (i = 0, e = getNumEqualities(); i < e; ++i) { 1089 // Find a local variable to eliminate using ith equality. 1090 for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j) 1091 if (std::abs(atEq(i, j)) == 1) 1092 break; 1093 1094 // Local variable can be eliminated using ith equality. 1095 if (j < f) 1096 break; 1097 } 1098 1099 // No equality can be used to eliminate a local variable. 1100 if (i == e) 1101 break; 1102 1103 // Use the ith equality to simplify other equalities. If any changes 1104 // are made to an equality constraint, it is normalized by GCD. 1105 for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) { 1106 if (atEq(k, j) != 0) { 1107 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true); 1108 equalities.normalizeRow(k); 1109 } 1110 } 1111 1112 // Use the ith equality to simplify inequalities. 1113 for (unsigned k = 0, t = getNumInequalities(); k < t; ++k) 1114 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false); 1115 1116 // Remove the ith equality and the found local variable. 1117 removeId(j); 1118 removeEquality(i); 1119 } 1120 } 1121 1122 void IntegerRelation::convertDimToLocal(unsigned dimStart, unsigned dimLimit) { 1123 assert(dimLimit <= getNumDimIds() && "Invalid dim pos range"); 1124 1125 if (dimStart >= dimLimit) 1126 return; 1127 1128 // Append new local variables corresponding to the dimensions to be converted. 1129 unsigned convertCount = dimLimit - dimStart; 1130 unsigned newLocalIdStart = getNumIds(); 1131 appendId(IdKind::Local, convertCount); 1132 1133 // Swap the new local variables with dimensions. 
1134 for (unsigned i = 0; i < convertCount; ++i) 1135 swapId(i + dimStart, i + newLocalIdStart); 1136 1137 // Remove dimensions converted to local variables. 1138 removeIdRange(IdKind::SetDim, dimStart, dimLimit); 1139 } 1140 1141 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) { 1142 assert(pos < getNumCols()); 1143 if (type == BoundType::EQ) { 1144 unsigned row = equalities.appendExtraRow(); 1145 equalities(row, pos) = 1; 1146 equalities(row, getNumCols() - 1) = -value; 1147 } else { 1148 unsigned row = inequalities.appendExtraRow(); 1149 inequalities(row, pos) = type == BoundType::LB ? 1 : -1; 1150 inequalities(row, getNumCols() - 1) = 1151 type == BoundType::LB ? -value : value; 1152 } 1153 } 1154 1155 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr, 1156 int64_t value) { 1157 assert(type != BoundType::EQ && "EQ not implemented"); 1158 assert(expr.size() == getNumCols()); 1159 unsigned row = inequalities.appendExtraRow(); 1160 for (unsigned i = 0, e = expr.size(); i < e; ++i) 1161 inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i]; 1162 inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += 1163 type == BoundType::LB ? -value : value; 1164 } 1165 1166 /// Adds a new local identifier as the floordiv of an affine function of other 1167 /// identifiers, the coefficients of which are provided in 'dividend' and with 1168 /// respect to a positive constant 'divisor'. Two constraints are added to the 1169 /// system to capture equivalence with the floordiv. 1170 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. 1171 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend, 1172 int64_t divisor) { 1173 assert(dividend.size() == getNumCols() && "incorrect dividend size"); 1174 assert(divisor > 0 && "positive divisor expected"); 1175 1176 appendId(IdKind::Local); 1177 1178 // Add two constraints for this new identifier 'q'. 
1179 SmallVector<int64_t, 8> bound(dividend.size() + 1); 1180 1181 // dividend - q * divisor >= 0 1182 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, 1183 bound.begin()); 1184 bound.back() = dividend.back(); 1185 bound[getNumIds() - 1] = -divisor; 1186 addInequality(bound); 1187 1188 // -dividend +qdivisor * q + divisor - 1 >= 0 1189 std::transform(bound.begin(), bound.end(), bound.begin(), 1190 std::negate<int64_t>()); 1191 bound[bound.size() - 1] += divisor - 1; 1192 addInequality(bound); 1193 } 1194 1195 /// Finds an equality that equates the specified identifier to a constant. 1196 /// Returns the position of the equality row. If 'symbolic' is set to true, 1197 /// symbols are also treated like a constant, i.e., an affine function of the 1198 /// symbols is also treated like a constant. Returns -1 if such an equality 1199 /// could not be found. 1200 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, 1201 bool symbolic = false) { 1202 assert(pos < cst.getNumIds() && "invalid position"); 1203 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1204 int64_t v = cst.atEq(r, pos); 1205 if (v * v != 1) 1206 continue; 1207 unsigned c; 1208 unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds(); 1209 // This checks for zeros in all positions other than 'pos' in [0, f) 1210 for (c = 0; c < f; c++) { 1211 if (c == pos) 1212 continue; 1213 if (cst.atEq(r, c) != 0) { 1214 // Dependent on another identifier. 1215 break; 1216 } 1217 } 1218 if (c == f) 1219 // Equality is free of other identifiers. 1220 return r; 1221 } 1222 return -1; 1223 } 1224 1225 LogicalResult IntegerRelation::constantFoldId(unsigned pos) { 1226 assert(pos < getNumIds() && "invalid position"); 1227 int rowIdx; 1228 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) 1229 return failure(); 1230 1231 // atEq(rowIdx, pos) is either -1 or 1. 
1232 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); 1233 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); 1234 setAndEliminate(pos, constVal); 1235 return success(); 1236 } 1237 1238 void IntegerRelation::constantFoldIdRange(unsigned pos, unsigned num) { 1239 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { 1240 if (failed(constantFoldId(t))) 1241 t++; 1242 } 1243 } 1244 1245 /// Returns a non-negative constant bound on the extent (upper bound - lower 1246 /// bound) of the specified identifier if it is found to be a constant; returns 1247 /// None if it's not a constant. This methods treats symbolic identifiers 1248 /// specially, i.e., it looks for constant differences between affine 1249 /// expressions involving only the symbolic identifiers. See comments at 1250 /// function definition for example. 'lb', if provided, is set to the lower 1251 /// bound associated with the constant difference. Note that 'lb' is purely 1252 /// symbolic and thus will contain the coefficients of the symbolic identifiers 1253 /// and the constant coefficient. 1254 // Egs: 0 <= i <= 15, return 16. 1255 // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) 1256 // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. 1257 // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = 1258 // ceil(s0 - 7 / 8) = floor(s0 / 8)). 1259 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize( 1260 unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor, 1261 SmallVectorImpl<int64_t> *ub, unsigned *minLbPos, 1262 unsigned *minUbPos) const { 1263 assert(pos < getNumDimIds() && "Invalid identifier position"); 1264 1265 // Find an equality for 'pos'^th identifier that equates it to some function 1266 // of the symbolic identifiers (+ constant). 
1267 int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); 1268 if (eqPos != -1) { 1269 auto eq = getEquality(eqPos); 1270 // If the equality involves a local var, punt for now. 1271 // TODO: this can be handled in the future by using the explicit 1272 // representation of the local vars. 1273 if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1, 1274 [](int64_t coeff) { return coeff == 0; })) 1275 return None; 1276 1277 // This identifier can only take a single value. 1278 if (lb) { 1279 // Set lb to that symbolic value. 1280 lb->resize(getNumSymbolIds() + 1); 1281 if (ub) 1282 ub->resize(getNumSymbolIds() + 1); 1283 for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { 1284 int64_t v = atEq(eqPos, pos); 1285 // atEq(eqRow, pos) is either -1 or 1. 1286 assert(v * v == 1); 1287 (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v 1288 : -atEq(eqPos, getNumDimIds() + c) / v; 1289 // Since this is an equality, ub = lb. 1290 if (ub) 1291 (*ub)[c] = (*lb)[c]; 1292 } 1293 assert(boundFloorDivisor && 1294 "both lb and divisor or none should be provided"); 1295 *boundFloorDivisor = 1; 1296 } 1297 if (minLbPos) 1298 *minLbPos = eqPos; 1299 if (minUbPos) 1300 *minUbPos = eqPos; 1301 return 1; 1302 } 1303 1304 // Check if the identifier appears at all in any of the inequalities. 1305 unsigned r, e; 1306 for (r = 0, e = getNumInequalities(); r < e; r++) { 1307 if (atIneq(r, pos) != 0) 1308 break; 1309 } 1310 if (r == e) 1311 // If it doesn't, there isn't a bound on it. 1312 return None; 1313 1314 // Positions of constraints that are lower/upper bounds on the variable. 1315 SmallVector<unsigned, 4> lbIndices, ubIndices; 1316 1317 // Gather all symbolic lower bounds and upper bounds of the variable, i.e., 1318 // the bounds can only involve symbolic (and local) identifiers. Since the 1319 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1320 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 
1321 getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, 1322 /*eqIndices=*/nullptr, /*offset=*/0, 1323 /*num=*/getNumDimIds()); 1324 1325 Optional<int64_t> minDiff = None; 1326 unsigned minLbPosition = 0, minUbPosition = 0; 1327 for (auto ubPos : ubIndices) { 1328 for (auto lbPos : lbIndices) { 1329 // Look for a lower bound and an upper bound that only differ by a 1330 // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. 1331 // For example, if ii is the pos^th variable, we are looking for 1332 // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The 1333 // minimum among all such constant differences is kept since that's the 1334 // constant bounding the extent of the pos^th variable. 1335 unsigned j, e; 1336 for (j = 0, e = getNumCols() - 1; j < e; j++) 1337 if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { 1338 break; 1339 } 1340 if (j < getNumCols() - 1) 1341 continue; 1342 int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + 1343 atIneq(lbPos, getNumCols() - 1) + 1, 1344 atIneq(lbPos, pos)); 1345 // This bound is non-negative by definition. 1346 diff = std::max<int64_t>(diff, 0); 1347 if (minDiff == None || diff < minDiff) { 1348 minDiff = diff; 1349 minLbPosition = lbPos; 1350 minUbPosition = ubPos; 1351 } 1352 } 1353 } 1354 if (lb && minDiff.hasValue()) { 1355 // Set lb to the symbolic lower bound. 1356 lb->resize(getNumSymbolIds() + 1); 1357 if (ub) 1358 ub->resize(getNumSymbolIds() + 1); 1359 // The lower bound is the ceildiv of the lb constraint over the coefficient 1360 // of the variable at 'pos'. We express the ceildiv equivalently as a floor 1361 // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + 1362 // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). 
1363 *boundFloorDivisor = atIneq(minLbPosition, pos); 1364 assert(*boundFloorDivisor == -atIneq(minUbPosition, pos)); 1365 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { 1366 (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c); 1367 } 1368 if (ub) { 1369 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) 1370 (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c); 1371 } 1372 // The lower bound leads to a ceildiv while the upper bound is a floordiv 1373 // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val + 1374 // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to 1375 // the constant term for the lower bound. 1376 (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1; 1377 } 1378 if (minLbPos) 1379 *minLbPos = minLbPosition; 1380 if (minUbPos) 1381 *minUbPos = minUbPosition; 1382 return minDiff; 1383 } 1384 1385 template <bool isLower> 1386 Optional<int64_t> 1387 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { 1388 assert(pos < getNumIds() && "invalid position"); 1389 // Project to 'pos'. 1390 projectOut(0, pos); 1391 projectOut(1, getNumIds() - 1); 1392 // Check if there's an equality equating the '0'^th identifier to a constant. 1393 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); 1394 if (eqRowIdx != -1) 1395 // atEq(rowIdx, 0) is either -1 or 1. 1396 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); 1397 1398 // Check if the identifier appears at all in any of the inequalities. 1399 unsigned r, e; 1400 for (r = 0, e = getNumInequalities(); r < e; r++) { 1401 if (atIneq(r, 0) != 0) 1402 break; 1403 } 1404 if (r == e) 1405 // If it doesn't, there isn't a bound on it. 1406 return None; 1407 1408 Optional<int64_t> minOrMaxConst = None; 1409 1410 // Take the max across all const lower bounds (or min across all constant 1411 // upper bounds). 
1412 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1413 if (isLower) { 1414 if (atIneq(r, 0) <= 0) 1415 // Not a lower bound. 1416 continue; 1417 } else if (atIneq(r, 0) >= 0) { 1418 // Not an upper bound. 1419 continue; 1420 } 1421 unsigned c, f; 1422 for (c = 0, f = getNumCols() - 1; c < f; c++) 1423 if (c != 0 && atIneq(r, c) != 0) 1424 break; 1425 if (c < getNumCols() - 1) 1426 // Not a constant bound. 1427 continue; 1428 1429 int64_t boundConst = 1430 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) 1431 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); 1432 if (isLower) { 1433 if (minOrMaxConst == None || boundConst > minOrMaxConst) 1434 minOrMaxConst = boundConst; 1435 } else { 1436 if (minOrMaxConst == None || boundConst < minOrMaxConst) 1437 minOrMaxConst = boundConst; 1438 } 1439 } 1440 return minOrMaxConst; 1441 } 1442 1443 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type, 1444 unsigned pos) const { 1445 if (type == BoundType::LB) 1446 return IntegerRelation(*this) 1447 .computeConstantLowerOrUpperBound</*isLower=*/true>(pos); 1448 if (type == BoundType::UB) 1449 return IntegerRelation(*this) 1450 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1451 1452 assert(type == BoundType::EQ && "expected EQ"); 1453 Optional<int64_t> lb = 1454 IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>( 1455 pos); 1456 Optional<int64_t> ub = 1457 IntegerRelation(*this) 1458 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1459 return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None; 1460 } 1461 1462 // A simple (naive and conservative) check for hyper-rectangularity. 1463 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const { 1464 assert(pos < getNumCols() - 1); 1465 // Check for two non-zero coefficients in the range [pos, pos + sum). 
1466 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1467 unsigned sum = 0; 1468 for (unsigned c = pos; c < pos + num; c++) { 1469 if (atIneq(r, c) != 0) 1470 sum++; 1471 } 1472 if (sum > 1) 1473 return false; 1474 } 1475 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1476 unsigned sum = 0; 1477 for (unsigned c = pos; c < pos + num; c++) { 1478 if (atEq(r, c) != 0) 1479 sum++; 1480 } 1481 if (sum > 1) 1482 return false; 1483 } 1484 return true; 1485 } 1486 1487 /// Removes duplicate constraints, trivially true constraints, and constraints 1488 /// that can be detected as redundant as a result of differing only in their 1489 /// constant term part. A constraint of the form <non-negative constant> >= 0 is 1490 /// considered trivially true. 1491 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to 1492 // remove duplicates in place. 1493 void IntegerRelation::removeTrivialRedundancy() { 1494 gcdTightenInequalities(); 1495 normalizeConstraintsByGCD(); 1496 1497 // A map used to detect redundancy stemming from constraints that only differ 1498 // in their constant term. The value stored is <row position, const term> 1499 // for a given row. 1500 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>> 1501 rowsWithoutConstTerm; 1502 // To unique rows. 1503 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet; 1504 1505 // Check if constraint is of the form <non-negative-constant> >= 0. 1506 auto isTriviallyValid = [&](unsigned r) -> bool { 1507 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { 1508 if (atIneq(r, c) != 0) 1509 return false; 1510 } 1511 return atIneq(r, getNumCols() - 1) >= 0; 1512 }; 1513 1514 // Detect and mark redundant constraints. 
1515 SmallVector<bool, 256> redunIneq(getNumInequalities(), false); 1516 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1517 int64_t *rowStart = &inequalities(r, 0); 1518 auto row = ArrayRef<int64_t>(rowStart, getNumCols()); 1519 if (isTriviallyValid(r) || !rowSet.insert(row).second) { 1520 redunIneq[r] = true; 1521 continue; 1522 } 1523 1524 // Among constraints that only differ in the constant term part, mark 1525 // everything other than the one with the smallest constant term redundant. 1526 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the 1527 // former two are redundant). 1528 int64_t constTerm = atIneq(r, getNumCols() - 1); 1529 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1); 1530 const auto &ret = 1531 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); 1532 if (!ret.second) { 1533 // Check if the other constraint has a higher constant term. 1534 auto &val = ret.first->second; 1535 if (val.second > constTerm) { 1536 // The stored row is redundant. Mark it so, and update with this one. 1537 redunIneq[val.first] = true; 1538 val = {r, constTerm}; 1539 } else { 1540 // The one stored makes this one redundant. 1541 redunIneq[r] = true; 1542 } 1543 } 1544 } 1545 1546 // Scan to get rid of all rows marked redundant, in-place. 1547 unsigned pos = 0; 1548 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) 1549 if (!redunIneq[r]) 1550 inequalities.copyRow(r, pos++); 1551 1552 inequalities.resizeVertically(pos); 1553 1554 // TODO: consider doing this for equalities as well, but probably not worth 1555 // the savings. 1556 } 1557 1558 #undef DEBUG_TYPE 1559 #define DEBUG_TYPE "fm" 1560 1561 /// Eliminates identifier at the specified position using Fourier-Motzkin 1562 /// variable elimination. This technique is exact for rational spaces but 1563 /// conservative (in "rare" cases) for integer spaces. 
/// The operation corresponds
/// to a projection operation yielding the (convex) set of integer points
/// contained in the rational shadow of the set. An emptiness test that relies
/// on this method will guarantee emptiness, i.e., it disproves the existence of
/// a solution if it says it's empty.
/// If a non-null isResultIntegerExact is passed, it is set to true if the
/// result is also integer exact. If it's set to false, the obtained solution
/// *may* not be exact, i.e., it may contain integer points that do not have an
/// integer pre-image in the original set.
///
/// Eg:
/// j >= 0, j <= i + 1
/// i >= 0, i <= N + 1
/// Eliminating i yields,
///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
///
/// If darkShadow = true, this method computes the dark shadow on elimination;
/// the dark shadow is a convex integer subset of the exact integer shadow. A
/// non-empty dark shadow proves the existence of an integer solution. The
/// elimination in such a case could however be an under-approximation, and thus
/// should not be used for scanning sets or used by itself for dependence
/// checking.
///
/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
///            ^
///            |
///            | * * * * o o
///         i  | * * o o o o
///            | o * * * * *
///            --------------->
///                 j ->
///
/// Eliminating i from this system (projecting on the j dimension):
/// rational shadow / integer light shadow:  1 <= j <= 6
/// dark shadow:                             3 <= j <= 6
/// exact integer shadow:                    j = 1 \union  3 <= j <= 6
/// holes/splinters:                         j = 2
///
/// darkShadow = false, isResultIntegerExact = nullptr are default values.
// TODO: a slight modification to yield dark shadow version of FM (tightened),
// which can prove the existence of a solution if there is one.
1604 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, 1605 bool *isResultIntegerExact) { 1606 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); 1607 LLVM_DEBUG(dump()); 1608 assert(pos < getNumIds() && "invalid position"); 1609 assert(hasConsistentState()); 1610 1611 // Check if this identifier can be eliminated through a substitution. 1612 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1613 if (atEq(r, pos) != 0) { 1614 // Use Gaussian elimination here (since we have an equality). 1615 LogicalResult ret = gaussianEliminateId(pos); 1616 (void)ret; 1617 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); 1618 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); 1619 LLVM_DEBUG(dump()); 1620 return; 1621 } 1622 } 1623 1624 // A fast linear time tightening. 1625 gcdTightenInequalities(); 1626 1627 // Check if the identifier appears at all in any of the inequalities. 1628 if (isColZero(pos)) { 1629 // If it doesn't appear, just remove the column and return. 1630 // TODO: refactor removeColumns to use it from here. 1631 removeId(pos); 1632 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1633 LLVM_DEBUG(dump()); 1634 return; 1635 } 1636 1637 // Positions of constraints that are lower bounds on the variable. 1638 SmallVector<unsigned, 4> lbIndices; 1639 // Positions of constraints that are lower bounds on the variable. 1640 SmallVector<unsigned, 4> ubIndices; 1641 // Positions of constraints that do not involve the variable. 1642 std::vector<unsigned> nbIndices; 1643 nbIndices.reserve(getNumInequalities()); 1644 1645 // Gather all lower bounds and upper bounds of the variable. Since the 1646 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1647 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1648 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1649 if (atIneq(r, pos) == 0) { 1650 // Id does not appear in bound. 
1651 nbIndices.push_back(r); 1652 } else if (atIneq(r, pos) >= 1) { 1653 // Lower bound. 1654 lbIndices.push_back(r); 1655 } else { 1656 // Upper bound. 1657 ubIndices.push_back(r); 1658 } 1659 } 1660 1661 // Set the number of dimensions, symbols, locals in the resulting system. 1662 unsigned newNumDomain = 1663 getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1); 1664 unsigned newNumRange = 1665 getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1); 1666 unsigned newNumSymbols = 1667 getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1); 1668 unsigned newNumLocals = 1669 getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1); 1670 1671 /// Create the new system which has one identifier less. 1672 IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(), 1673 getNumEqualities(), getNumCols() - 1, newNumDomain, 1674 newNumRange, newNumSymbols, newNumLocals); 1675 1676 // This will be used to check if the elimination was integer exact. 1677 unsigned lcmProducts = 1; 1678 1679 // Let x be the variable we are eliminating. 1680 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note 1681 // that c_l, c_u >= 1) we have: 1682 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u 1683 // We thus generate a constraint: 1684 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. 1685 // Note if c_l = c_u = 1, all integer points captured by the resulting 1686 // constraint correspond to integer points in the original system (i.e., they 1687 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is 1688 // integer exact. 
1689 for (auto ubPos : ubIndices) { 1690 for (auto lbPos : lbIndices) { 1691 SmallVector<int64_t, 4> ineq; 1692 ineq.reserve(newRel.getNumCols()); 1693 int64_t lbCoeff = atIneq(lbPos, pos); 1694 // Note that in the comments above, ubCoeff is the negation of the 1695 // coefficient in the canonical form as the view taken here is that of the 1696 // term being moved to the other size of '>='. 1697 int64_t ubCoeff = -atIneq(ubPos, pos); 1698 // TODO: refactor this loop to avoid all branches inside. 1699 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1700 if (l == pos) 1701 continue; 1702 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); 1703 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); 1704 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + 1705 atIneq(lbPos, l) * (lcm / lbCoeff)); 1706 lcmProducts *= lcm; 1707 } 1708 if (darkShadow) { 1709 // The dark shadow is a convex subset of the exact integer shadow. If 1710 // there is a point here, it proves the existence of a solution. 1711 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; 1712 } 1713 // TODO: we need to have a way to add inequalities in-place in 1714 // IntegerRelation instead of creating and copying over. 1715 newRel.addInequality(ineq); 1716 } 1717 } 1718 1719 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1) 1720 << "\n"); 1721 if (lcmProducts == 1 && isResultIntegerExact) 1722 *isResultIntegerExact = true; 1723 1724 // Copy over the constraints not involving this variable. 1725 for (auto nbPos : nbIndices) { 1726 SmallVector<int64_t, 4> ineq; 1727 ineq.reserve(getNumCols() - 1); 1728 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1729 if (l == pos) 1730 continue; 1731 ineq.push_back(atIneq(nbPos, l)); 1732 } 1733 newRel.addInequality(ineq); 1734 } 1735 1736 assert(newRel.getNumConstraints() == 1737 lbIndices.size() * ubIndices.size() + nbIndices.size()); 1738 1739 // Copy over the equalities. 
1740 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1741 SmallVector<int64_t, 4> eq; 1742 eq.reserve(newRel.getNumCols()); 1743 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1744 if (l == pos) 1745 continue; 1746 eq.push_back(atEq(r, l)); 1747 } 1748 newRel.addEquality(eq); 1749 } 1750 1751 // GCD tightening and normalization allows detection of more trivially 1752 // redundant constraints. 1753 newRel.gcdTightenInequalities(); 1754 newRel.normalizeConstraintsByGCD(); 1755 newRel.removeTrivialRedundancy(); 1756 clearAndCopyFrom(newRel); 1757 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1758 LLVM_DEBUG(dump()); 1759 } 1760 1761 #undef DEBUG_TYPE 1762 #define DEBUG_TYPE "presburger" 1763 1764 void IntegerRelation::projectOut(unsigned pos, unsigned num) { 1765 if (num == 0) 1766 return; 1767 1768 // 'pos' can be at most getNumCols() - 2 if num > 0. 1769 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position"); 1770 assert(pos + num < getNumCols() && "invalid range"); 1771 1772 // Eliminate as many identifiers as possible using Gaussian elimination. 1773 unsigned currentPos = pos; 1774 unsigned numToEliminate = num; 1775 unsigned numGaussianEliminated = 0; 1776 1777 while (currentPos < getNumIds()) { 1778 unsigned curNumEliminated = 1779 gaussianEliminateIds(currentPos, currentPos + numToEliminate); 1780 ++currentPos; 1781 numToEliminate -= curNumEliminated + 1; 1782 numGaussianEliminated += curNumEliminated; 1783 } 1784 1785 // Eliminate the remaining using Fourier-Motzkin. 1786 for (unsigned i = 0; i < num - numGaussianEliminated; i++) { 1787 unsigned numToEliminate = num - numGaussianEliminated - i; 1788 fourierMotzkinEliminate( 1789 getBestIdToEliminate(*this, pos, pos + numToEliminate)); 1790 } 1791 1792 // Fast/trivial simplifications. 1793 gcdTightenInequalities(); 1794 // Normalize constraints after tightening since the latter impacts this, but 1795 // not the other way round. 
1796 normalizeConstraintsByGCD(); 1797 } 1798 1799 namespace { 1800 1801 enum BoundCmpResult { Greater, Less, Equal, Unknown }; 1802 1803 /// Compares two affine bounds whose coefficients are provided in 'first' and 1804 /// 'second'. The last coefficient is the constant term. 1805 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 1806 assert(a.size() == b.size()); 1807 1808 // For the bounds to be comparable, their corresponding identifier 1809 // coefficients should be equal; the constant terms are then compared to 1810 // determine less/greater/equal. 1811 1812 if (!std::equal(a.begin(), a.end() - 1, b.begin())) 1813 return Unknown; 1814 1815 if (a.back() == b.back()) 1816 return Equal; 1817 1818 return a.back() < b.back() ? Less : Greater; 1819 } 1820 } // namespace 1821 1822 // Returns constraints that are common to both A & B. 1823 static void getCommonConstraints(const IntegerRelation &a, 1824 const IntegerRelation &b, IntegerRelation &c) { 1825 c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(), 1826 a.getNumSymbolIds(), a.getNumLocalIds()); 1827 // a naive O(n^2) check should be enough here given the input sizes. 1828 for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) { 1829 for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) { 1830 if (a.getInequality(r) == b.getInequality(s)) { 1831 c.addInequality(a.getInequality(r)); 1832 break; 1833 } 1834 } 1835 } 1836 for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) { 1837 for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) { 1838 if (a.getEquality(r) == b.getEquality(s)) { 1839 c.addEquality(a.getEquality(r)); 1840 break; 1841 } 1842 } 1843 } 1844 } 1845 1846 // Computes the bounding box with respect to 'other' by finding the min of the 1847 // lower bounds and the max of the upper bounds along each of the dimensions. 
1848 LogicalResult 1849 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { 1850 assert(PresburgerLocalSpace::isEqual(otherCst) && "Spaces should match."); 1851 assert(getNumLocalIds() == 0 && "local ids not supported yet here"); 1852 1853 // Get the constraints common to both systems; these will be added as is to 1854 // the union. 1855 IntegerRelation commonCst; 1856 getCommonConstraints(*this, otherCst, commonCst); 1857 1858 std::vector<SmallVector<int64_t, 8>> boundingLbs; 1859 std::vector<SmallVector<int64_t, 8>> boundingUbs; 1860 boundingLbs.reserve(2 * getNumDimIds()); 1861 boundingUbs.reserve(2 * getNumDimIds()); 1862 1863 // To hold lower and upper bounds for each dimension. 1864 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; 1865 // To compute min of lower bounds and max of upper bounds for each dimension. 1866 SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1); 1867 SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1); 1868 // To compute final new lower and upper bounds for the union. 1869 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); 1870 1871 int64_t lbFloorDivisor, otherLbFloorDivisor; 1872 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 1873 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); 1874 if (!extent.hasValue()) 1875 // TODO: symbolic extents when necessary. 1876 // TODO: handle union if a dimension is unbounded. 1877 return failure(); 1878 1879 auto otherExtent = otherCst.getConstantBoundOnDimSize( 1880 d, &otherLb, &otherLbFloorDivisor, &otherUb); 1881 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor) 1882 // TODO: symbolic extents when necessary. 1883 return failure(); 1884 1885 assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); 1886 1887 auto res = compareBounds(lb, otherLb); 1888 // Identify min. 
1889 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { 1890 minLb = lb; 1891 // Since the divisor is for a floordiv, we need to convert to ceildiv, 1892 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=> 1893 // div * i >= expr - div + 1. 1894 minLb.back() -= lbFloorDivisor - 1; 1895 } else if (res == BoundCmpResult::Greater) { 1896 minLb = otherLb; 1897 minLb.back() -= otherLbFloorDivisor - 1; 1898 } else { 1899 // Uncomparable - check for constant lower/upper bounds. 1900 auto constLb = getConstantBound(BoundType::LB, d); 1901 auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d); 1902 if (!constLb.hasValue() || !constOtherLb.hasValue()) 1903 return failure(); 1904 std::fill(minLb.begin(), minLb.end(), 0); 1905 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); 1906 } 1907 1908 // Do the same for ub's but max of upper bounds. Identify max. 1909 auto uRes = compareBounds(ub, otherUb); 1910 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { 1911 maxUb = ub; 1912 } else if (uRes == BoundCmpResult::Less) { 1913 maxUb = otherUb; 1914 } else { 1915 // Uncomparable - check for constant lower/upper bounds. 1916 auto constUb = getConstantBound(BoundType::UB, d); 1917 auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d); 1918 if (!constUb.hasValue() || !constOtherUb.hasValue()) 1919 return failure(); 1920 std::fill(maxUb.begin(), maxUb.end(), 0); 1921 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); 1922 } 1923 1924 std::fill(newLb.begin(), newLb.end(), 0); 1925 std::fill(newUb.begin(), newUb.end(), 0); 1926 1927 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, 1928 // and so it's the divisor for newLb and newUb as well. 1929 newLb[d] = lbFloorDivisor; 1930 newUb[d] = -lbFloorDivisor; 1931 // Copy over the symbolic part + constant term. 
1932 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); 1933 std::transform(newLb.begin() + getNumDimIds(), newLb.end(), 1934 newLb.begin() + getNumDimIds(), std::negate<int64_t>()); 1935 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); 1936 1937 boundingLbs.push_back(newLb); 1938 boundingUbs.push_back(newUb); 1939 } 1940 1941 // Clear all constraints and add the lower/upper bounds for the bounding box. 1942 clearConstraints(); 1943 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 1944 addInequality(boundingLbs[d]); 1945 addInequality(boundingUbs[d]); 1946 } 1947 1948 // Add the constraints that were common to both systems. 1949 append(commonCst); 1950 removeTrivialRedundancy(); 1951 1952 // TODO: copy over pure symbolic constraints from this and 'other' over to the 1953 // union (since the above are just the union along dimensions); we shouldn't 1954 // be discarding any other constraints on the symbols. 1955 1956 return success(); 1957 } 1958 1959 bool IntegerRelation::isColZero(unsigned pos) const { 1960 unsigned rowPos; 1961 return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) && 1962 !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos); 1963 } 1964 1965 /// Find positions of inequalities and equalities that do not have a coefficient 1966 /// for [pos, pos + num) identifiers. 1967 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos, 1968 unsigned num, 1969 SmallVectorImpl<unsigned> &nbIneqIndices, 1970 SmallVectorImpl<unsigned> &nbEqIndices) { 1971 assert(pos < cst.getNumIds() && "invalid start position"); 1972 assert(pos + num <= cst.getNumIds() && "invalid limit"); 1973 1974 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 1975 // The bounds are to be independent of [offset, offset + num) columns. 
1976 unsigned c; 1977 for (c = pos; c < pos + num; ++c) { 1978 if (cst.atIneq(r, c) != 0) 1979 break; 1980 } 1981 if (c == pos + num) 1982 nbIneqIndices.push_back(r); 1983 } 1984 1985 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1986 // The bounds are to be independent of [offset, offset + num) columns. 1987 unsigned c; 1988 for (c = pos; c < pos + num; ++c) { 1989 if (cst.atEq(r, c) != 0) 1990 break; 1991 } 1992 if (c == pos + num) 1993 nbEqIndices.push_back(r); 1994 } 1995 } 1996 1997 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) { 1998 assert(pos + num <= getNumIds() && "invalid range"); 1999 2000 // Remove constraints that are independent of these identifiers. 2001 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices; 2002 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices); 2003 2004 // Iterate in reverse so that indices don't have to be updated. 2005 // TODO: This method can be made more efficient (because removal of each 2006 // inequality leads to much shifting/copying in the underlying buffer). 
2007 for (auto nbIndex : llvm::reverse(nbIneqIndices)) 2008 removeInequality(nbIndex); 2009 for (auto nbIndex : llvm::reverse(nbEqIndices)) 2010 removeEquality(nbIndex); 2011 } 2012 2013 void IntegerRelation::printSpace(raw_ostream &os) const { 2014 PresburgerLocalSpace::print(os); 2015 os << getNumConstraints() << " constraints\n"; 2016 } 2017 2018 void IntegerRelation::print(raw_ostream &os) const { 2019 assert(hasConsistentState()); 2020 printSpace(os); 2021 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 2022 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2023 os << atEq(i, j) << " "; 2024 } 2025 os << "= 0\n"; 2026 } 2027 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 2028 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2029 os << atIneq(i, j) << " "; 2030 } 2031 os << ">= 0\n"; 2032 } 2033 os << '\n'; 2034 } 2035 2036 void IntegerRelation::dump() const { print(llvm::errs()); } 2037 2038 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) { 2039 assert((kind != IdKind::Domain || num == 0) && 2040 "Domain has to be zero in a set"); 2041 return IntegerRelation::insertId(kind, pos, num); 2042 } 2043