1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A class to represent an relation over integer tuples. A relation is 10 // represented as a constraint system over a space of tuples of integer valued 11 // variables supporting symbolic variables and existential quantification. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/Presburger/IntegerRelation.h" 16 #include "mlir/Analysis/Presburger/LinearTransform.h" 17 #include "mlir/Analysis/Presburger/PWMAFunction.h" 18 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 19 #include "mlir/Analysis/Presburger/Simplex.h" 20 #include "mlir/Analysis/Presburger/Utils.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/DenseSet.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "presburger" 26 27 using namespace mlir; 28 using namespace presburger; 29 30 using llvm::SmallDenseMap; 31 using llvm::SmallDenseSet; 32 33 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const { 34 return std::make_unique<IntegerRelation>(*this); 35 } 36 37 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const { 38 return std::make_unique<IntegerPolyhedron>(*this); 39 } 40 41 void IntegerRelation::setSpace(const PresburgerSpace &oSpace) { 42 assert(space.getNumVars() == oSpace.getNumVars() && "invalid space!"); 43 space = oSpace; 44 } 45 46 void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) { 47 assert(oSpace.getNumLocalVars() == 0 && "no locals should be present!"); 48 assert(oSpace.getNumVars() <= getNumVars() && "invalid space!"); 49 unsigned newNumLocals = getNumVars() - oSpace.getNumVars(); 50 space = oSpace; 51 space.insertVar(VarKind::Local, 0, newNumLocals); 52 } 53 54 void IntegerRelation::append(const IntegerRelation &other) { 55 assert(space.isEqual(other.getSpace()) && "Spaces must be equal."); 56 57 inequalities.reserveRows(inequalities.getNumRows() + 58 other.getNumInequalities()); 59 equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities()); 60 61 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { 62 addInequality(other.getInequality(r)); 63 } 64 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { 65 addEquality(other.getEquality(r)); 66 } 67 } 68 69 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const { 70 IntegerRelation result = *this; 71 result.mergeLocalVars(other); 72 result.append(other); 73 return result; 74 } 75 76 bool IntegerRelation::isEqual(const IntegerRelation &other) const { 77 assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible."); 78 return PresburgerRelation(*this).isEqual(PresburgerRelation(other)); 79 } 80 81 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const { 82 assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible."); 83 return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other)); 84 } 85 86 MaybeOptimum<SmallVector<Fraction, 8>> 87 IntegerRelation::findRationalLexMin() const { 88 assert(getNumSymbolVars() == 0 && "Symbols are not supported!"); 89 MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin = 90 LexSimplex(*this).findRationalLexMin(); 91 92 if (!maybeLexMin.isBounded()) 93 return maybeLexMin; 94 95 // The Simplex returns the lexmin over all the variables including locals. But 96 // locals are not actually part of the space and should not be returned in the 97 // result. Since the locals are placed last in the list of variables, they 98 // will be minimized last in the lexmin. So simply truncating out the locals 99 // from the end of the answer gives the desired lexmin over the dimensions. 100 assert(maybeLexMin->size() == getNumVars() && 101 "Incorrect number of vars in lexMin!"); 102 maybeLexMin->resize(getNumDimAndSymbolVars()); 103 return maybeLexMin; 104 } 105 106 MaybeOptimum<SmallVector<int64_t, 8>> 107 IntegerRelation::findIntegerLexMin() const { 108 assert(getNumSymbolVars() == 0 && "Symbols are not supported!"); 109 MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin = 110 LexSimplex(*this).findIntegerLexMin(); 111 112 if (!maybeLexMin.isBounded()) 113 return maybeLexMin.getKind(); 114 115 // The Simplex returns the lexmin over all the variables including locals. But 116 // locals are not actually part of the space and should not be returned in the 117 // result. Since the locals are placed last in the list of variables, they 118 // will be minimized last in the lexmin. So simply truncating out the locals 119 // from the end of the answer gives the desired lexmin over the dimensions. 120 assert(maybeLexMin->size() == getNumVars() && 121 "Incorrect number of vars in lexMin!"); 122 maybeLexMin->resize(getNumDimAndSymbolVars()); 123 return maybeLexMin; 124 } 125 126 static bool rangeIsZero(ArrayRef<int64_t> range) { 127 return llvm::all_of(range, [](int64_t x) { return x == 0; }); 128 } 129 130 static void removeConstraintsInvolvingVarRange(IntegerRelation &poly, 131 unsigned begin, unsigned count) { 132 // We loop until i > 0 and index into i - 1 to avoid sign issues. 133 // 134 // We iterate backwards so that whether we remove constraint i - 1 or not, the 135 // next constraint to be tested is always i - 2. 136 for (unsigned i = poly.getNumEqualities(); i > 0; i--) 137 if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count))) 138 poly.removeEquality(i - 1); 139 for (unsigned i = poly.getNumInequalities(); i > 0; i--) 140 if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count))) 141 poly.removeInequality(i - 1); 142 } 143 144 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const { 145 return {getSpace(), getNumInequalities(), getNumEqualities()}; 146 } 147 148 void IntegerRelation::truncateVarKind(VarKind kind, unsigned num) { 149 unsigned curNum = getNumVarKind(kind); 150 assert(num <= curNum && "Can't truncate to more vars!"); 151 removeVarRange(kind, num, curNum); 152 } 153 154 void IntegerRelation::truncateVarKind(VarKind kind, 155 const CountsSnapshot &counts) { 156 truncateVarKind(kind, counts.getSpace().getNumVarKind(kind)); 157 } 158 159 void IntegerRelation::truncate(const CountsSnapshot &counts) { 160 truncateVarKind(VarKind::Domain, counts); 161 truncateVarKind(VarKind::Range, counts); 162 truncateVarKind(VarKind::Symbol, counts); 163 truncateVarKind(VarKind::Local, counts); 164 removeInequalityRange(counts.getNumIneqs(), getNumInequalities()); 165 removeEqualityRange(counts.getNumEqs(), getNumEqualities()); 166 } 167 168 PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const { 169 // If there are no locals, we're done. 170 if (getNumLocalVars() == 0) 171 return PresburgerRelation(*this); 172 173 // Move all the non-div locals to the end, as the current API to 174 // SymbolicLexMin requires these to form a contiguous range. 175 // 176 // Take a copy so we can perform mutations. 177 IntegerRelation copy = *this; 178 std::vector<MaybeLocalRepr> reprs; 179 copy.getLocalReprs(reprs); 180 181 // Iterate through all the locals. The last `numNonDivLocals` are the locals 182 // that have been scanned already and do not have division representations. 183 unsigned numNonDivLocals = 0; 184 unsigned offset = copy.getVarKindOffset(VarKind::Local); 185 for (unsigned i = 0, e = copy.getNumLocalVars(); i < e - numNonDivLocals;) { 186 if (!reprs[i]) { 187 // Whenever we come across a local that does not have a division 188 // representation, we swap it to the `numNonDivLocals`-th last position 189 // and increment `numNonDivLocal`s. `reprs` also needs to be swapped. 190 copy.swapVar(offset + i, offset + e - numNonDivLocals - 1); 191 std::swap(reprs[i], reprs[e - numNonDivLocals - 1]); 192 ++numNonDivLocals; 193 continue; 194 } 195 ++i; 196 } 197 198 // If there are no non-div locals, we're done. 199 if (numNonDivLocals == 0) 200 return PresburgerRelation(*this); 201 202 // We computeSymbolicIntegerLexMin by considering the non-div locals as 203 // "non-symbols" and considering everything else as "symbols". This will 204 // compute a function mapping assignments to "symbols" to the 205 // lexicographically minimal valid assignment of "non-symbols", when a 206 // satisfying assignment exists. It separately returns the set of assignments 207 // to the "symbols" such that a satisfying assignment to the "non-symbols" 208 // exists but the lexmin is unbounded. We basically want to find the set of 209 // values of the "symbols" such that an assignment to the "non-symbols" 210 // exists, which is the union of the domain of the returned lexmin function 211 // and the returned set of assignments to the "symbols" that makes the lexmin 212 // unbounded. 213 SymbolicLexMin lexminResult = 214 SymbolicLexSimplex(copy, /*symbolOffset*/ 0, 215 IntegerPolyhedron(PresburgerSpace::getSetSpace( 216 /*numDims=*/copy.getNumVars() - numNonDivLocals))) 217 .computeSymbolicIntegerLexMin(); 218 PresburgerRelation result = 219 lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain); 220 221 // The result set might lie in the wrong space -- all its ids are dims. 222 // Set it to the desired space and return. 223 PresburgerSpace space = getSpace(); 224 space.removeVarRange(VarKind::Local, 0, getNumLocalVars()); 225 result.setSpace(space); 226 return result; 227 } 228 229 SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const { 230 // Symbol and Domain vars will be used as symbols for symbolic lexmin. 231 // In other words, for every value of the symbols and domain, return the 232 // lexmin value of the (range, locals). 233 llvm::SmallBitVector isSymbol(getNumVars(), false); 234 isSymbol.set(getVarKindOffset(VarKind::Symbol), 235 getVarKindEnd(VarKind::Symbol)); 236 isSymbol.set(getVarKindOffset(VarKind::Domain), 237 getVarKindEnd(VarKind::Domain)); 238 // Compute the symbolic lexmin of the dims and locals, with the symbols being 239 // the actual symbols of this set. 240 SymbolicLexMin result = 241 SymbolicLexSimplex(*this, 242 IntegerPolyhedron(PresburgerSpace::getSetSpace( 243 /*numDims=*/getNumDomainVars(), 244 /*numSymbols=*/getNumSymbolVars())), 245 isSymbol) 246 .computeSymbolicIntegerLexMin(); 247 248 // We want to return only the lexmin over the dims, so strip the locals from 249 // the computed lexmin. 250 result.lexmin.truncateOutput(result.lexmin.getNumOutputs() - 251 getNumLocalVars()); 252 return result; 253 } 254 255 PresburgerRelation 256 IntegerRelation::subtract(const PresburgerRelation &set) const { 257 return PresburgerRelation(*this).subtract(set); 258 } 259 260 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) { 261 assert(pos <= getNumVarKind(kind)); 262 263 unsigned insertPos = space.insertVar(kind, pos, num); 264 inequalities.insertColumns(insertPos, num); 265 equalities.insertColumns(insertPos, num); 266 return insertPos; 267 } 268 269 unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) { 270 unsigned pos = getNumVarKind(kind); 271 return insertVar(kind, pos, num); 272 } 273 274 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) { 275 assert(eq.size() == getNumCols()); 276 unsigned row = equalities.appendExtraRow(); 277 for (unsigned i = 0, e = eq.size(); i < e; ++i) 278 equalities(row, i) = eq[i]; 279 } 280 281 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) { 282 assert(inEq.size() == getNumCols()); 283 unsigned row = inequalities.appendExtraRow(); 284 for (unsigned i = 0, e = inEq.size(); i < e; ++i) 285 inequalities(row, i) = inEq[i]; 286 } 287 288 void IntegerRelation::removeVar(VarKind kind, unsigned pos) { 289 removeVarRange(kind, pos, pos + 1); 290 } 291 292 void IntegerRelation::removeVar(unsigned pos) { removeVarRange(pos, pos + 1); } 293 294 void IntegerRelation::removeVarRange(VarKind kind, unsigned varStart, 295 unsigned varLimit) { 296 assert(varLimit <= getNumVarKind(kind)); 297 298 if (varStart >= varLimit) 299 return; 300 301 // Remove eliminated variables from the constraints. 302 unsigned offset = getVarKindOffset(kind); 303 equalities.removeColumns(offset + varStart, varLimit - varStart); 304 inequalities.removeColumns(offset + varStart, varLimit - varStart); 305 306 // Remove eliminated variables from the space. 307 space.removeVarRange(kind, varStart, varLimit); 308 } 309 310 void IntegerRelation::removeVarRange(unsigned varStart, unsigned varLimit) { 311 assert(varLimit <= getNumVars()); 312 313 if (varStart >= varLimit) 314 return; 315 316 // Helper function to remove vars of the specified kind in the given range 317 // [start, limit), The range is absolute (i.e. it is not relative to the kind 318 // of variable). Also updates `limit` to reflect the deleted variables. 319 auto removeVarKindInRange = [this](VarKind kind, unsigned &start, 320 unsigned &limit) { 321 if (start >= limit) 322 return; 323 324 unsigned offset = getVarKindOffset(kind); 325 unsigned num = getNumVarKind(kind); 326 327 // Get `start`, `limit` relative to the specified kind. 328 unsigned relativeStart = 329 start <= offset ? 0 : std::min(num, start - offset); 330 unsigned relativeLimit = 331 limit <= offset ? 0 : std::min(num, limit - offset); 332 333 // Remove vars of the specified kind in the relative range. 334 removeVarRange(kind, relativeStart, relativeLimit); 335 336 // Update `limit` to reflect deleted variables. 337 // `start` does not need to be updated because any variables that are 338 // deleted are after position `start`. 339 limit -= relativeLimit - relativeStart; 340 }; 341 342 removeVarKindInRange(VarKind::Domain, varStart, varLimit); 343 removeVarKindInRange(VarKind::Range, varStart, varLimit); 344 removeVarKindInRange(VarKind::Symbol, varStart, varLimit); 345 removeVarKindInRange(VarKind::Local, varStart, varLimit); 346 } 347 348 void IntegerRelation::removeEquality(unsigned pos) { 349 equalities.removeRow(pos); 350 } 351 352 void IntegerRelation::removeInequality(unsigned pos) { 353 inequalities.removeRow(pos); 354 } 355 356 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) { 357 if (start >= end) 358 return; 359 equalities.removeRows(start, end - start); 360 } 361 362 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) { 363 if (start >= end) 364 return; 365 inequalities.removeRows(start, end - start); 366 } 367 368 void IntegerRelation::swapVar(unsigned posA, unsigned posB) { 369 assert(posA < getNumVars() && "invalid position A"); 370 assert(posB < getNumVars() && "invalid position B"); 371 372 if (posA == posB) 373 return; 374 375 inequalities.swapColumns(posA, posB); 376 equalities.swapColumns(posA, posB); 377 } 378 379 void IntegerRelation::clearConstraints() { 380 equalities.resizeVertically(0); 381 inequalities.resizeVertically(0); 382 } 383 384 /// Gather all lower and upper bounds of the variable at `pos`, and 385 /// optionally any equalities on it. In addition, the bounds are to be 386 /// independent of variables in position range [`offset`, `offset` + `num`). 387 void IntegerRelation::getLowerAndUpperBoundIndices( 388 unsigned pos, SmallVectorImpl<unsigned> *lbIndices, 389 SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices, 390 unsigned offset, unsigned num) const { 391 assert(pos < getNumVars() && "invalid position"); 392 assert(offset + num < getNumCols() && "invalid range"); 393 394 // Checks for a constraint that has a non-zero coeff for the variables in 395 // the position range [offset, offset + num) while ignoring `pos`. 396 auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) { 397 unsigned c, f; 398 auto cst = isEq ? getEquality(r) : getInequality(r); 399 for (c = offset, f = offset + num; c < f; ++c) { 400 if (c == pos) 401 continue; 402 if (cst[c] != 0) 403 break; 404 } 405 return c < f; 406 }; 407 408 // Gather all lower bounds and upper bounds of the variable. Since the 409 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 410 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 411 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 412 // The bounds are to be independent of [offset, offset + num) columns. 413 if (containsConstraintDependentOnRange(r, /*isEq=*/false)) 414 continue; 415 if (atIneq(r, pos) >= 1) { 416 // Lower bound. 417 lbIndices->push_back(r); 418 } else if (atIneq(r, pos) <= -1) { 419 // Upper bound. 420 ubIndices->push_back(r); 421 } 422 } 423 424 // An equality is both a lower and upper bound. Record any equalities 425 // involving the pos^th variable. 426 if (!eqIndices) 427 return; 428 429 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 430 if (atEq(r, pos) == 0) 431 continue; 432 if (containsConstraintDependentOnRange(r, /*isEq=*/true)) 433 continue; 434 eqIndices->push_back(r); 435 } 436 } 437 438 bool IntegerRelation::hasConsistentState() const { 439 if (!inequalities.hasConsistentState()) 440 return false; 441 if (!equalities.hasConsistentState()) 442 return false; 443 return true; 444 } 445 446 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) { 447 if (values.empty()) 448 return; 449 assert(pos + values.size() <= getNumVars() && 450 "invalid position or too many values"); 451 // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the 452 // constant term and removing the var x_j. We do this for all the vars 453 // pos, pos + 1, ... pos + values.size() - 1. 454 unsigned constantColPos = getNumCols() - 1; 455 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 456 inequalities.addToColumn(i + pos, constantColPos, values[i]); 457 for (unsigned i = 0, numVals = values.size(); i < numVals; ++i) 458 equalities.addToColumn(i + pos, constantColPos, values[i]); 459 removeVarRange(pos, pos + values.size()); 460 } 461 462 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) { 463 *this = other; 464 } 465 466 // Searches for a constraint with a non-zero coefficient at `colIdx` in 467 // equality (isEq=true) or inequality (isEq=false) constraints. 468 // Returns true and sets row found in search in `rowIdx`, false otherwise. 469 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, 470 unsigned *rowIdx) const { 471 assert(colIdx < getNumCols() && "position out of bounds"); 472 auto at = [&](unsigned rowIdx) -> int64_t { 473 return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); 474 }; 475 unsigned e = isEq ? getNumEqualities() : getNumInequalities(); 476 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { 477 if (at(*rowIdx) != 0) { 478 return true; 479 } 480 } 481 return false; 482 } 483 484 void IntegerRelation::normalizeConstraintsByGCD() { 485 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 486 equalities.normalizeRow(i); 487 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 488 inequalities.normalizeRow(i); 489 } 490 491 bool IntegerRelation::hasInvalidConstraint() const { 492 assert(hasConsistentState()); 493 auto check = [&](bool isEq) -> bool { 494 unsigned numCols = getNumCols(); 495 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); 496 for (unsigned i = 0, e = numRows; i < e; ++i) { 497 unsigned j; 498 for (j = 0; j < numCols - 1; ++j) { 499 int64_t v = isEq ? atEq(i, j) : atIneq(i, j); 500 // Skip rows with non-zero variable coefficients. 501 if (v != 0) 502 break; 503 } 504 if (j < numCols - 1) { 505 continue; 506 } 507 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. 508 // Example invalid constraints include: '1 == 0' or '-1 >= 0' 509 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); 510 if ((isEq && v != 0) || (!isEq && v < 0)) { 511 return true; 512 } 513 } 514 return false; 515 }; 516 if (check(/*isEq=*/true)) 517 return true; 518 return check(/*isEq=*/false); 519 } 520 521 /// Eliminate variable from constraint at `rowIdx` based on coefficient at 522 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be 523 /// updated as they have already been eliminated. 524 static void eliminateFromConstraint(IntegerRelation *constraints, 525 unsigned rowIdx, unsigned pivotRow, 526 unsigned pivotCol, unsigned elimColStart, 527 bool isEq) { 528 // Skip if equality 'rowIdx' if same as 'pivotRow'. 529 if (isEq && rowIdx == pivotRow) 530 return; 531 auto at = [&](unsigned i, unsigned j) -> int64_t { 532 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); 533 }; 534 int64_t leadCoeff = at(rowIdx, pivotCol); 535 // Skip if leading coefficient at 'rowIdx' is already zero. 536 if (leadCoeff == 0) 537 return; 538 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); 539 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; 540 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); 541 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); 542 int64_t rowMultiplier = lcm / std::abs(leadCoeff); 543 544 unsigned numCols = constraints->getNumCols(); 545 for (unsigned j = 0; j < numCols; ++j) { 546 // Skip updating column 'j' if it was just eliminated. 547 if (j >= elimColStart && j < pivotCol) 548 continue; 549 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + 550 rowMultiplier * at(rowIdx, j); 551 isEq ? constraints->atEq(rowIdx, j) = v 552 : constraints->atIneq(rowIdx, j) = v; 553 } 554 } 555 556 /// Returns the position of the variable that has the minimum <number of lower 557 /// bounds> times <number of upper bounds> from the specified range of 558 /// variables [start, end). It is often best to eliminate in the increasing 559 /// order of these counts when doing Fourier-Motzkin elimination since FM adds 560 /// that many new constraints. 561 static unsigned getBestVarToEliminate(const IntegerRelation &cst, 562 unsigned start, unsigned end) { 563 assert(start < cst.getNumVars() && end < cst.getNumVars() + 1); 564 565 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { 566 unsigned numLb = 0; 567 unsigned numUb = 0; 568 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 569 if (cst.atIneq(r, pos) > 0) { 570 ++numLb; 571 } else if (cst.atIneq(r, pos) < 0) { 572 ++numUb; 573 } 574 } 575 return numLb * numUb; 576 }; 577 578 unsigned minLoc = start; 579 unsigned min = getProductOfNumLowerUpperBounds(start); 580 for (unsigned c = start + 1; c < end; c++) { 581 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); 582 if (numLbUbProduct < min) { 583 min = numLbUbProduct; 584 minLoc = c; 585 } 586 } 587 return minLoc; 588 } 589 590 // Checks for emptiness of the set by eliminating variables successively and 591 // using the GCD test (on all equality constraints) and checking for trivially 592 // invalid constraints. Returns 'true' if the constraint system is found to be 593 // empty; false otherwise. 594 bool IntegerRelation::isEmpty() const { 595 if (isEmptyByGCDTest() || hasInvalidConstraint()) 596 return true; 597 598 IntegerRelation tmpCst(*this); 599 600 // First, eliminate as many local variables as possible using equalities. 601 tmpCst.removeRedundantLocalVars(); 602 if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint()) 603 return true; 604 605 // Eliminate as many variables as possible using Gaussian elimination. 606 unsigned currentPos = 0; 607 while (currentPos < tmpCst.getNumVars()) { 608 tmpCst.gaussianEliminateVars(currentPos, tmpCst.getNumVars()); 609 ++currentPos; 610 // We check emptiness through trivial checks after eliminating each ID to 611 // detect emptiness early. Since the checks isEmptyByGCDTest() and 612 // hasInvalidConstraint() are linear time and single sweep on the constraint 613 // buffer, this appears reasonable - but can optimize in the future. 614 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) 615 return true; 616 } 617 618 // Eliminate the remaining using FM. 619 for (unsigned i = 0, e = tmpCst.getNumVars(); i < e; i++) { 620 tmpCst.fourierMotzkinEliminate( 621 getBestVarToEliminate(tmpCst, 0, tmpCst.getNumVars())); 622 // Check for a constraint explosion. This rarely happens in practice, but 623 // this check exists as a safeguard against improperly constructed 624 // constraint systems or artificially created arbitrarily complex systems 625 // that aren't the intended use case for IntegerRelation. This is 626 // needed since FM has a worst case exponential complexity in theory. 627 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumVars()) { 628 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n"); 629 return false; 630 } 631 632 // FM wouldn't have modified the equalities in any way. So no need to again 633 // run GCD test. Check for trivial invalid constraints. 634 if (tmpCst.hasInvalidConstraint()) 635 return true; 636 } 637 return false; 638 } 639 640 // Runs the GCD test on all equality constraints. Returns 'true' if this test 641 // fails on any equality. Returns 'false' otherwise. 642 // This test can be used to disprove the existence of a solution. If it returns 643 // true, no integer solution to the equality constraints can exist. 644 // 645 // GCD test definition: 646 // 647 // The equality constraint: 648 // 649 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 650 // 651 // has an integer solution iff: 652 // 653 // GCD of c_1, c_2, ..., c_n divides c_0. 654 // 655 bool IntegerRelation::isEmptyByGCDTest() const { 656 assert(hasConsistentState()); 657 unsigned numCols = getNumCols(); 658 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 659 uint64_t gcd = std::abs(atEq(i, 0)); 660 for (unsigned j = 1; j < numCols - 1; ++j) { 661 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); 662 } 663 int64_t v = std::abs(atEq(i, numCols - 1)); 664 if (gcd > 0 && (v % gcd != 0)) { 665 return true; 666 } 667 } 668 return false; 669 } 670 671 // Returns a matrix where each row is a vector along which the polytope is 672 // bounded. The span of the returned vectors is guaranteed to contain all 673 // such vectors. The returned vectors are NOT guaranteed to be linearly 674 // independent. This function should not be called on empty sets. 675 // 676 // It is sufficient to check the perpendiculars of the constraints, as the set 677 // of perpendiculars which are bounded must span all bounded directions. 678 Matrix IntegerRelation::getBoundedDirections() const { 679 // Note that it is necessary to add the equalities too (which the constructor 680 // does) even though we don't need to check if they are bounded; whether an 681 // inequality is bounded or not depends on what other constraints, including 682 // equalities, are present. 683 Simplex simplex(*this); 684 685 assert(!simplex.isEmpty() && "It is not meaningful to ask whether a " 686 "direction is bounded in an empty set."); 687 688 SmallVector<unsigned, 8> boundedIneqs; 689 // The constructor adds the inequalities to the simplex first, so this 690 // processes all the inequalities. 691 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 692 if (simplex.isBoundedAlongConstraint(i)) 693 boundedIneqs.push_back(i); 694 } 695 696 // The direction vector is given by the coefficients and does not include the 697 // constant term, so the matrix has one fewer column. 698 unsigned dirsNumCols = getNumCols() - 1; 699 Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols); 700 701 // Copy the bounded inequalities. 702 unsigned row = 0; 703 for (unsigned i : boundedIneqs) { 704 for (unsigned col = 0; col < dirsNumCols; ++col) 705 dirs(row, col) = atIneq(i, col); 706 ++row; 707 } 708 709 // Copy the equalities. All the equalities' perpendiculars are bounded. 710 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 711 for (unsigned col = 0; col < dirsNumCols; ++col) 712 dirs(row, col) = atEq(i, col); 713 ++row; 714 } 715 716 return dirs; 717 } 718 719 bool IntegerRelation::isIntegerEmpty() const { return !findIntegerSample(); } 720 721 /// Let this set be S. If S is bounded then we directly call into the GBR 722 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e., 723 /// vectors v such that S extends to infinity along v or -v. In this case we 724 /// use an algorithm described in the integer set library (isl) manual and used 725 /// by the isl_set_sample function in that library. The algorithm is: 726 /// 727 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all 728 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the 729 /// dimensions. 730 /// 731 /// 2) Construct a set B by removing all constraints that involve 732 /// the unbounded dimensions and then deleting the unbounded dimensions. Note 733 /// that B is a Bounded set. 734 /// 735 /// 3) Try to obtain a sample from B using the GBR sampling 736 /// algorithm. If no sample is found, return that S is empty. 737 /// 738 /// 4) Otherwise, substitute the obtained sample into S*T to obtain a set 739 /// C. C is a full-dimensional Cone and always contains a sample. 740 /// 741 /// 5) Obtain an integer sample from C. 742 /// 743 /// 6) Return T*v, where v is the concatenation of the samples from B and C. 744 /// 745 /// The following is a sketch of a proof that 746 /// a) If the algorithm returns empty, then S is empty. 747 /// b) If the algorithm returns a sample, it is a valid sample in S. 748 /// 749 /// The algorithm returns empty only if B is empty, in which case S*T is 750 /// certainly empty since B was obtained by removing constraints and then 751 /// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector 752 /// v is in S*T iff T*v is in S. So in this case, since 753 /// S*T is empty, S is empty too. 754 /// 755 /// Otherwise, the algorithm substitutes the sample from B into S*T. All the 756 /// constraints of S*T that did not involve unbounded dimensions are satisfied 757 /// by this substitution. All dimensions in the linear span of the dimensions 758 /// outside the prefix are unbounded in S*T (step 1). Substituting values for 759 /// the bounded dimensions cannot make these dimensions bounded, and these are 760 /// the only remaining dimensions in C, so C is unbounded along every vector (in 761 /// the positive or negative direction, or both). C is hence a full-dimensional 762 /// cone and therefore always contains an integer point. 763 /// 764 /// Concatenating the samples from B and C gives a sample v in S*T, so the 765 /// returned sample T*v is a sample in S. 766 Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const { 767 // First, try the GCD test heuristic. 768 if (isEmptyByGCDTest()) 769 return {}; 770 771 Simplex simplex(*this); 772 if (simplex.isEmpty()) 773 return {}; 774 775 // For a bounded set, we directly call into the GBR sampling algorithm. 776 if (!simplex.isUnbounded()) 777 return simplex.findIntegerSample(); 778 779 // The set is unbounded. We cannot directly use the GBR algorithm. 780 // 781 // m is a matrix containing, in each row, a vector in which S is 782 // bounded, such that the linear span of all these dimensions contains all 783 // bounded dimensions in S. 784 Matrix m = getBoundedDirections(); 785 // In column echelon form, each row of m occupies only the first rank(m) 786 // columns and has zeros on the other columns. The transform T that brings S 787 // to column echelon form is unimodular as well, so this is a suitable 788 // transform to use in step 1 of the algorithm. 789 std::pair<unsigned, LinearTransform> result = 790 LinearTransform::makeTransformToColumnEchelon(std::move(m)); 791 const LinearTransform &transform = result.second; 792 // 1) Apply T to S to obtain S*T. 793 IntegerRelation transformedSet = transform.applyTo(*this); 794 795 // 2) Remove the unbounded dimensions and constraints involving them to 796 // obtain a bounded set. 797 IntegerRelation boundedSet(transformedSet); 798 unsigned numBoundedDims = result.first; 799 unsigned numUnboundedDims = getNumVars() - numBoundedDims; 800 removeConstraintsInvolvingVarRange(boundedSet, numBoundedDims, 801 numUnboundedDims); 802 boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars()); 803 804 // 3) Try to obtain a sample from the bounded set. 805 Optional<SmallVector<int64_t, 8>> boundedSample = 806 Simplex(boundedSet).findIntegerSample(); 807 if (!boundedSample) 808 return {}; 809 assert(boundedSet.containsPoint(*boundedSample) && 810 "Simplex returned an invalid sample!"); 811 812 // 4) Substitute the values of the bounded dimensions into S*T to obtain a 813 // full-dimensional cone, which necessarily contains an integer sample. 814 transformedSet.setAndEliminate(0, *boundedSample); 815 IntegerRelation &cone = transformedSet; 816 817 // 5) Obtain an integer sample from the cone. 818 // 819 // We shrink the cone such that for any rational point in the shrunken cone, 820 // rounding up each of the point's coordinates produces a point that still 821 // lies in the original cone. 822 // 823 // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i. 824 // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the 825 // shrunken cone will have the inequality tightened by some amount s, such 826 // that if x satisfies the shrunken cone's tightened inequality, then x + e 827 // satisfies the original inequality, i.e., 828 // 829 // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0 830 // 831 // for any e_i values in [0, 1). In fact, we will handle the slightly more 832 // general case where e_i can be in [0, 1]. For example, consider the 833 // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low 834 // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS 835 // is minimized when we add 1 to the x_i with negative coefficient a_i and 836 // keep the other x_i the same. In the example, we would get x = (3, 1, 1), 837 // changing the value of the LHS by -3 + -7 = -10. 838 // 839 // In general, the value of the LHS can change by at most the sum of the 840 // negative a_i, so we accomodate this by shifting the inequality by this 841 // amount for the shrunken cone. 842 for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) { 843 for (unsigned j = 0; j < cone.getNumVars(); ++j) { 844 int64_t coeff = cone.atIneq(i, j); 845 if (coeff < 0) 846 cone.atIneq(i, cone.getNumVars()) += coeff; 847 } 848 } 849 850 // Obtain an integer sample in the cone by rounding up a rational point from 851 // the shrunken cone. Shrinking the cone amounts to shifting its apex 852 // "inwards" without changing its "shape"; the shrunken cone is still a 853 // full-dimensional cone and is hence non-empty. 854 Simplex shrunkenConeSimplex(cone); 855 assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!"); 856 857 // The sample will always exist since the shrunken cone is non-empty. 858 SmallVector<Fraction, 8> shrunkenConeSample = 859 *shrunkenConeSimplex.getRationalSample(); 860 861 SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil)); 862 863 // 6) Return transform * concat(boundedSample, coneSample). 864 SmallVector<int64_t, 8> &sample = *boundedSample; 865 sample.append(coneSample.begin(), coneSample.end()); 866 return transform.postMultiplyWithColumn(sample); 867 } 868 869 /// Helper to evaluate an affine expression at a point. 870 /// The expression is a list of coefficients for the dimensions followed by the 871 /// constant term. 872 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) { 873 assert(expr.size() == 1 + point.size() && 874 "Dimensionalities of point and expression don't match!"); 875 int64_t value = expr.back(); 876 for (unsigned i = 0; i < point.size(); ++i) 877 value += expr[i] * point[i]; 878 return value; 879 } 880 881 /// A point satisfies an equality iff the value of the equality at the 882 /// expression is zero, and it satisfies an inequality iff the value of the 883 /// inequality at that point is non-negative. 884 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const { 885 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 886 if (valueAt(getEquality(i), point) != 0) 887 return false; 888 } 889 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 890 if (valueAt(getInequality(i), point) < 0) 891 return false; 892 } 893 return true; 894 } 895 896 /// Just substitute the values given and check if an integer sample exists for 897 /// the local vars. 898 /// 899 /// TODO: this could be made more efficient by handling divisions separately. 900 /// Instead of finding an integer sample over all the locals, we can first 901 /// compute the values of the locals that have division representations and 902 /// only use the integer emptiness check for the locals that don't have this. 903 /// Handling this correctly requires ordering the divs, though. 904 Optional<SmallVector<int64_t, 8>> 905 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const { 906 assert(point.size() == getNumVars() - getNumLocalVars() && 907 "Point should contain all vars except locals!"); 908 assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() && 909 "This function depends on locals being stored last!"); 910 IntegerRelation copy = *this; 911 copy.setAndEliminate(0, point); 912 return copy.findIntegerSample(); 913 } 914 915 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const { 916 std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalVars()); 917 SmallVector<unsigned, 4> denominators(getNumLocalVars()); 918 getLocalReprs(dividends, denominators, repr); 919 } 920 921 void IntegerRelation::getLocalReprs( 922 std::vector<SmallVector<int64_t, 8>> ÷nds, 923 SmallVector<unsigned, 4> &denominators) const { 924 std::vector<MaybeLocalRepr> repr(getNumLocalVars()); 925 getLocalReprs(dividends, denominators, repr); 926 } 927 928 void IntegerRelation::getLocalReprs( 929 std::vector<SmallVector<int64_t, 8>> ÷nds, 930 SmallVector<unsigned, 4> &denominators, 931 std::vector<MaybeLocalRepr> &repr) const { 932 933 repr.resize(getNumLocalVars()); 934 dividends.resize(getNumLocalVars()); 935 denominators.resize(getNumLocalVars()); 936 937 SmallVector<bool, 8> foundRepr(getNumVars(), false); 938 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) 939 foundRepr[i] = true; 940 941 unsigned divOffset = getNumDimAndSymbolVars(); 942 bool changed; 943 do { 944 // Each time changed is true, at end of this iteration, one or more local 945 // vars have been detected as floor divs. 946 changed = false; 947 for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) { 948 if (!foundRepr[i + divOffset]) { 949 MaybeLocalRepr res = computeSingleVarRepr( 950 *this, foundRepr, divOffset + i, dividends[i], denominators[i]); 951 if (!res) 952 continue; 953 foundRepr[i + divOffset] = true; 954 repr[i] = res; 955 changed = true; 956 } 957 } 958 } while (changed); 959 960 // Set 0 denominator for variables for which no division representation 961 // could be found. 962 for (unsigned i = 0, e = repr.size(); i < e; ++i) 963 if (!repr[i]) 964 denominators[i] = 0; 965 } 966 967 /// Tightens inequalities given that we are dealing with integer spaces. This is 968 /// analogous to the GCD test but applied to inequalities. The constant term can 969 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., 970 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a 971 /// fast method - linear in the number of coefficients. 972 // Example on how this affects practical cases: consider the scenario: 973 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield 974 // j >= 100 instead of the tighter (exact) j >= 128. 975 void IntegerRelation::gcdTightenInequalities() { 976 unsigned numCols = getNumCols(); 977 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 978 // Normalize the constraint and tighten the constant term by the GCD. 979 int64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1); 980 if (gcd > 1) 981 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd); 982 } 983 } 984 985 // Eliminates all variable variables in column range [posStart, posLimit). 986 // Returns the number of variables eliminated. 987 unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, 988 unsigned posLimit) { 989 // Return if variable positions to eliminate are out of range. 990 assert(posLimit <= getNumVars()); 991 assert(hasConsistentState()); 992 993 if (posStart >= posLimit) 994 return 0; 995 996 gcdTightenInequalities(); 997 998 unsigned pivotCol = 0; 999 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { 1000 // Find a row which has a non-zero coefficient in column 'j'. 1001 unsigned pivotRow; 1002 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) { 1003 // No pivot row in equalities with non-zero at 'pivotCol'. 1004 if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) { 1005 // If inequalities are also non-zero in 'pivotCol', it can be 1006 // eliminated. 1007 continue; 1008 } 1009 break; 1010 } 1011 1012 // Eliminate variable at 'pivotCol' from each equality row. 1013 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 1014 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 1015 /*isEq=*/true); 1016 equalities.normalizeRow(i); 1017 } 1018 1019 // Eliminate variable at 'pivotCol' from each inequality row. 1020 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 1021 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 1022 /*isEq=*/false); 1023 inequalities.normalizeRow(i); 1024 } 1025 removeEquality(pivotRow); 1026 gcdTightenInequalities(); 1027 } 1028 // Update position limit based on number eliminated. 1029 posLimit = pivotCol; 1030 // Remove eliminated columns from all constraints. 1031 removeVarRange(posStart, posLimit); 1032 return posLimit - posStart; 1033 } 1034 1035 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin 1036 // to check if a constraint is redundant. 1037 void IntegerRelation::removeRedundantInequalities() { 1038 SmallVector<bool, 32> redun(getNumInequalities(), false); 1039 // To check if an inequality is redundant, we replace the inequality by its 1040 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting 1041 // system is empty. If it is, the inequality is redundant. 1042 IntegerRelation tmpCst(*this); 1043 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1044 // Change the inequality to its complement. 1045 tmpCst.inequalities.negateRow(r); 1046 --tmpCst.atIneq(r, tmpCst.getNumCols() - 1); 1047 if (tmpCst.isEmpty()) { 1048 redun[r] = true; 1049 // Zero fill the redundant inequality. 1050 inequalities.fillRow(r, /*value=*/0); 1051 tmpCst.inequalities.fillRow(r, /*value=*/0); 1052 } else { 1053 // Reverse the change (to avoid recreating tmpCst each time). 1054 ++tmpCst.atIneq(r, tmpCst.getNumCols() - 1); 1055 tmpCst.inequalities.negateRow(r); 1056 } 1057 } 1058 1059 unsigned pos = 0; 1060 for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { 1061 if (!redun[r]) 1062 inequalities.copyRow(r, pos++); 1063 } 1064 inequalities.resizeVertically(pos); 1065 } 1066 1067 // A more complex check to eliminate redundant inequalities and equalities. Uses 1068 // Simplex to check if a constraint is redundant. 1069 void IntegerRelation::removeRedundantConstraints() { 1070 // First, we run gcdTightenInequalities. This allows us to catch some 1071 // constraints which are not redundant when considering rational solutions 1072 // but are redundant in terms of integer solutions. 1073 gcdTightenInequalities(); 1074 Simplex simplex(*this); 1075 simplex.detectRedundant(); 1076 1077 unsigned pos = 0; 1078 unsigned numIneqs = getNumInequalities(); 1079 // Scan to get rid of all inequalities marked redundant, in-place. In Simplex, 1080 // the first constraints added are the inequalities. 1081 for (unsigned r = 0; r < numIneqs; r++) { 1082 if (!simplex.isMarkedRedundant(r)) 1083 inequalities.copyRow(r, pos++); 1084 } 1085 inequalities.resizeVertically(pos); 1086 1087 // Scan to get rid of all equalities marked redundant, in-place. In Simplex, 1088 // after the inequalities, a pair of constraints for each equality is added. 1089 // An equality is redundant if both the inequalities in its pair are 1090 // redundant. 1091 pos = 0; 1092 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1093 if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) && 1094 simplex.isMarkedRedundant(numIneqs + 2 * r + 1))) 1095 equalities.copyRow(r, pos++); 1096 } 1097 equalities.resizeVertically(pos); 1098 } 1099 1100 Optional<uint64_t> IntegerRelation::computeVolume() const { 1101 assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!"); 1102 1103 Simplex simplex(*this); 1104 // If the polytope is rationally empty, there are certainly no integer 1105 // points. 1106 if (simplex.isEmpty()) 1107 return 0; 1108 1109 // Just find the maximum and minimum integer value of each non-local var 1110 // separately, thus finding the number of integer values each such var can 1111 // take. Multiplying these together gives a valid overapproximation of the 1112 // number of integer points in the relation. The result this gives is 1113 // equivalent to projecting (rationally) the relation onto its non-local vars 1114 // and returning the number of integer points in a minimal axis-parallel 1115 // hyperrectangular overapproximation of that. 1116 // 1117 // We also handle the special case where one dimension is unbounded and 1118 // another dimension can take no integer values. In this case, the volume is 1119 // zero. 1120 // 1121 // If there is no such empty dimension, if any dimension is unbounded we 1122 // just return the result as unbounded. 1123 uint64_t count = 1; 1124 SmallVector<int64_t, 8> dim(getNumVars() + 1); 1125 bool hasUnboundedVar = false; 1126 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) { 1127 dim[i] = 1; 1128 MaybeOptimum<int64_t> min, max; 1129 std::tie(min, max) = simplex.computeIntegerBounds(dim); 1130 dim[i] = 0; 1131 1132 assert((!min.isEmpty() && !max.isEmpty()) && 1133 "Polytope should be rationally non-empty!"); 1134 1135 // One of the dimensions is unbounded. Note this fact. We will return 1136 // unbounded if none of the other dimensions makes the volume zero. 1137 if (min.isUnbounded() || max.isUnbounded()) { 1138 hasUnboundedVar = true; 1139 continue; 1140 } 1141 1142 // In this case there are no valid integer points and the volume is 1143 // definitely zero. 1144 if (min.getBoundedOptimum() > max.getBoundedOptimum()) 1145 return 0; 1146 1147 count *= (*max - *min + 1); 1148 } 1149 1150 if (count == 0) 1151 return 0; 1152 if (hasUnboundedVar) 1153 return {}; 1154 return count; 1155 } 1156 1157 void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) { 1158 assert(posA < getNumLocalVars() && "Invalid local var position"); 1159 assert(posB < getNumLocalVars() && "Invalid local var position"); 1160 1161 unsigned localOffset = getVarKindOffset(VarKind::Local); 1162 posA += localOffset; 1163 posB += localOffset; 1164 inequalities.addToColumn(posB, posA, 1); 1165 equalities.addToColumn(posB, posA, 1); 1166 removeVar(posB); 1167 } 1168 1169 /// Adds additional local ids to the sets such that they both have the union 1170 /// of the local ids in each set, without changing the set of points that 1171 /// lie in `this` and `other`. 1172 /// 1173 /// To detect local ids that always take the same value, each local id is 1174 /// represented as a floordiv with constant denominator in terms of other ids. 1175 /// After extracting these divisions, local ids in `other` with the same 1176 /// division representation as some other local id in any set are considered 1177 /// duplicate and are merged. 1178 /// 1179 /// It is possible that division representation for some local id cannot be 1180 /// obtained, and thus these local ids are not considered for detecting 1181 /// duplicates. 1182 unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) { 1183 IntegerRelation &relA = *this; 1184 IntegerRelation &relB = other; 1185 1186 unsigned oldALocals = relA.getNumLocalVars(); 1187 1188 // Merge function that merges the local variables in both sets by treating 1189 // them as the same variable. 1190 auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool { 1191 // We only merge from local at pos j to local at pos i, where j > i. 1192 if (i >= j) 1193 return false; 1194 1195 // If i < oldALocals, we are trying to merge duplicate divs. Since we do not 1196 // want to merge duplicates in A, we ignore this call. 1197 if (j < oldALocals) 1198 return false; 1199 1200 // Merge local at pos j into local at position i. 1201 relA.eliminateRedundantLocalVar(i, j); 1202 relB.eliminateRedundantLocalVar(i, j); 1203 return true; 1204 }; 1205 1206 presburger::mergeLocalVars(*this, other, merge); 1207 1208 // Since we do not remove duplicate divisions in relA, this is guranteed to be 1209 // non-negative. 1210 return relA.getNumLocalVars() - oldALocals; 1211 } 1212 1213 bool IntegerRelation::hasOnlyDivLocals() const { 1214 std::vector<MaybeLocalRepr> reprs; 1215 getLocalReprs(reprs); 1216 return llvm::all_of(reprs, 1217 [](const MaybeLocalRepr &repr) { return bool(repr); }); 1218 } 1219 1220 void IntegerRelation::removeDuplicateDivs() { 1221 std::vector<SmallVector<int64_t, 8>> divs; 1222 SmallVector<unsigned, 4> denoms; 1223 1224 getLocalReprs(divs, denoms); 1225 auto merge = [this](unsigned i, unsigned j) -> bool { 1226 eliminateRedundantLocalVar(i, j); 1227 return true; 1228 }; 1229 presburger::removeDuplicateDivs(divs, denoms, 1230 getVarKindOffset(VarKind::Local), merge); 1231 } 1232 1233 /// Removes local variables using equalities. Each equality is checked if it 1234 /// can be reduced to the form: `e = affine-expr`, where `e` is a local 1235 /// variable and `affine-expr` is an affine expression not containing `e`. 1236 /// If an equality satisfies this form, the local variable is replaced in 1237 /// each constraint and then removed. The equality used to replace this local 1238 /// variable is also removed. 1239 void IntegerRelation::removeRedundantLocalVars() { 1240 // Normalize the equality constraints to reduce coefficients of local 1241 // variables to 1 wherever possible. 1242 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 1243 equalities.normalizeRow(i); 1244 1245 while (true) { 1246 unsigned i, e, j, f; 1247 for (i = 0, e = getNumEqualities(); i < e; ++i) { 1248 // Find a local variable to eliminate using ith equality. 1249 for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j) 1250 if (std::abs(atEq(i, j)) == 1) 1251 break; 1252 1253 // Local variable can be eliminated using ith equality. 1254 if (j < f) 1255 break; 1256 } 1257 1258 // No equality can be used to eliminate a local variable. 1259 if (i == e) 1260 break; 1261 1262 // Use the ith equality to simplify other equalities. If any changes 1263 // are made to an equality constraint, it is normalized by GCD. 1264 for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) { 1265 if (atEq(k, j) != 0) { 1266 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true); 1267 equalities.normalizeRow(k); 1268 } 1269 } 1270 1271 // Use the ith equality to simplify inequalities. 1272 for (unsigned k = 0, t = getNumInequalities(); k < t; ++k) 1273 eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false); 1274 1275 // Remove the ith equality and the found local variable. 1276 removeVar(j); 1277 removeEquality(i); 1278 } 1279 } 1280 1281 void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart, 1282 unsigned varLimit, VarKind dstKind, 1283 unsigned pos) { 1284 assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range"); 1285 1286 if (varStart >= varLimit) 1287 return; 1288 1289 // Append new local variables corresponding to the dimensions to be converted. 1290 unsigned convertCount = varLimit - varStart; 1291 unsigned newVarsBegin = insertVar(dstKind, pos, convertCount); 1292 1293 // Swap the new local variables with dimensions. 1294 // 1295 // Essentially, this moves the information corresponding to the specified ids 1296 // of kind `srcKind` to the `convertCount` newly created ids of kind 1297 // `dstKind`. In particular, this moves the columns in the constraint 1298 // matrices, and zeros out the initially occupied columns (because the newly 1299 // created ids we're swapping with were zero-initialized). 1300 unsigned offset = getVarKindOffset(srcKind); 1301 for (unsigned i = 0; i < convertCount; ++i) 1302 swapVar(offset + varStart + i, newVarsBegin + i); 1303 1304 // Complete the move by deleting the initially occupied columns. 1305 removeVarRange(srcKind, varStart, varLimit); 1306 } 1307 1308 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) { 1309 assert(pos < getNumCols()); 1310 if (type == BoundType::EQ) { 1311 unsigned row = equalities.appendExtraRow(); 1312 equalities(row, pos) = 1; 1313 equalities(row, getNumCols() - 1) = -value; 1314 } else { 1315 unsigned row = inequalities.appendExtraRow(); 1316 inequalities(row, pos) = type == BoundType::LB ? 1 : -1; 1317 inequalities(row, getNumCols() - 1) = 1318 type == BoundType::LB ? -value : value; 1319 } 1320 } 1321 1322 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr, 1323 int64_t value) { 1324 assert(type != BoundType::EQ && "EQ not implemented"); 1325 assert(expr.size() == getNumCols()); 1326 unsigned row = inequalities.appendExtraRow(); 1327 for (unsigned i = 0, e = expr.size(); i < e; ++i) 1328 inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i]; 1329 inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += 1330 type == BoundType::LB ? -value : value; 1331 } 1332 1333 /// Adds a new local variable as the floordiv of an affine function of other 1334 /// variables, the coefficients of which are provided in 'dividend' and with 1335 /// respect to a positive constant 'divisor'. Two constraints are added to the 1336 /// system to capture equivalence with the floordiv. 1337 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. 1338 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend, 1339 int64_t divisor) { 1340 assert(dividend.size() == getNumCols() && "incorrect dividend size"); 1341 assert(divisor > 0 && "positive divisor expected"); 1342 1343 appendVar(VarKind::Local); 1344 1345 SmallVector<int64_t, 8> dividendCopy(dividend.begin(), dividend.end()); 1346 dividendCopy.insert(dividendCopy.end() - 1, 0); 1347 addInequality( 1348 getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); 1349 addInequality( 1350 getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); 1351 } 1352 1353 /// Finds an equality that equates the specified variable to a constant. 1354 /// Returns the position of the equality row. If 'symbolic' is set to true, 1355 /// symbols are also treated like a constant, i.e., an affine function of the 1356 /// symbols is also treated like a constant. Returns -1 if such an equality 1357 /// could not be found. 1358 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, 1359 bool symbolic = false) { 1360 assert(pos < cst.getNumVars() && "invalid position"); 1361 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1362 int64_t v = cst.atEq(r, pos); 1363 if (v * v != 1) 1364 continue; 1365 unsigned c; 1366 unsigned f = symbolic ? cst.getNumDimVars() : cst.getNumVars(); 1367 // This checks for zeros in all positions other than 'pos' in [0, f) 1368 for (c = 0; c < f; c++) { 1369 if (c == pos) 1370 continue; 1371 if (cst.atEq(r, c) != 0) { 1372 // Dependent on another variable. 1373 break; 1374 } 1375 } 1376 if (c == f) 1377 // Equality is free of other variables. 1378 return r; 1379 } 1380 return -1; 1381 } 1382 1383 LogicalResult IntegerRelation::constantFoldVar(unsigned pos) { 1384 assert(pos < getNumVars() && "invalid position"); 1385 int rowIdx; 1386 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) 1387 return failure(); 1388 1389 // atEq(rowIdx, pos) is either -1 or 1. 1390 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); 1391 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); 1392 setAndEliminate(pos, constVal); 1393 return success(); 1394 } 1395 1396 void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) { 1397 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { 1398 if (failed(constantFoldVar(t))) 1399 t++; 1400 } 1401 } 1402 1403 /// Returns a non-negative constant bound on the extent (upper bound - lower 1404 /// bound) of the specified variable if it is found to be a constant; returns 1405 /// None if it's not a constant. This methods treats symbolic variables 1406 /// specially, i.e., it looks for constant differences between affine 1407 /// expressions involving only the symbolic variables. See comments at 1408 /// function definition for example. 'lb', if provided, is set to the lower 1409 /// bound associated with the constant difference. Note that 'lb' is purely 1410 /// symbolic and thus will contain the coefficients of the symbolic variables 1411 /// and the constant coefficient. 1412 // Egs: 0 <= i <= 15, return 16. 1413 // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) 1414 // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. 1415 // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = 1416 // ceil(s0 - 7 / 8) = floor(s0 / 8)). 1417 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize( 1418 unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor, 1419 SmallVectorImpl<int64_t> *ub, unsigned *minLbPos, 1420 unsigned *minUbPos) const { 1421 assert(pos < getNumDimVars() && "Invalid variable position"); 1422 1423 // Find an equality for 'pos'^th variable that equates it to some function 1424 // of the symbolic variables (+ constant). 1425 int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); 1426 if (eqPos != -1) { 1427 auto eq = getEquality(eqPos); 1428 // If the equality involves a local var, punt for now. 1429 // TODO: this can be handled in the future by using the explicit 1430 // representation of the local vars. 1431 if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1, 1432 [](int64_t coeff) { return coeff == 0; })) 1433 return None; 1434 1435 // This variable can only take a single value. 1436 if (lb) { 1437 // Set lb to that symbolic value. 1438 lb->resize(getNumSymbolVars() + 1); 1439 if (ub) 1440 ub->resize(getNumSymbolVars() + 1); 1441 for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) { 1442 int64_t v = atEq(eqPos, pos); 1443 // atEq(eqRow, pos) is either -1 or 1. 1444 assert(v * v == 1); 1445 (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v 1446 : -atEq(eqPos, getNumDimVars() + c) / v; 1447 // Since this is an equality, ub = lb. 1448 if (ub) 1449 (*ub)[c] = (*lb)[c]; 1450 } 1451 assert(boundFloorDivisor && 1452 "both lb and divisor or none should be provided"); 1453 *boundFloorDivisor = 1; 1454 } 1455 if (minLbPos) 1456 *minLbPos = eqPos; 1457 if (minUbPos) 1458 *minUbPos = eqPos; 1459 return 1; 1460 } 1461 1462 // Check if the variable appears at all in any of the inequalities. 1463 unsigned r, e; 1464 for (r = 0, e = getNumInequalities(); r < e; r++) { 1465 if (atIneq(r, pos) != 0) 1466 break; 1467 } 1468 if (r == e) 1469 // If it doesn't, there isn't a bound on it. 1470 return None; 1471 1472 // Positions of constraints that are lower/upper bounds on the variable. 1473 SmallVector<unsigned, 4> lbIndices, ubIndices; 1474 1475 // Gather all symbolic lower bounds and upper bounds of the variable, i.e., 1476 // the bounds can only involve symbolic (and local) variables. Since the 1477 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1478 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1479 getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, 1480 /*eqIndices=*/nullptr, /*offset=*/0, 1481 /*num=*/getNumDimVars()); 1482 1483 Optional<int64_t> minDiff = None; 1484 unsigned minLbPosition = 0, minUbPosition = 0; 1485 for (auto ubPos : ubIndices) { 1486 for (auto lbPos : lbIndices) { 1487 // Look for a lower bound and an upper bound that only differ by a 1488 // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. 1489 // For example, if ii is the pos^th variable, we are looking for 1490 // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The 1491 // minimum among all such constant differences is kept since that's the 1492 // constant bounding the extent of the pos^th variable. 1493 unsigned j, e; 1494 for (j = 0, e = getNumCols() - 1; j < e; j++) 1495 if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { 1496 break; 1497 } 1498 if (j < getNumCols() - 1) 1499 continue; 1500 int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + 1501 atIneq(lbPos, getNumCols() - 1) + 1, 1502 atIneq(lbPos, pos)); 1503 // This bound is non-negative by definition. 1504 diff = std::max<int64_t>(diff, 0); 1505 if (minDiff == None || diff < minDiff) { 1506 minDiff = diff; 1507 minLbPosition = lbPos; 1508 minUbPosition = ubPos; 1509 } 1510 } 1511 } 1512 if (lb && minDiff) { 1513 // Set lb to the symbolic lower bound. 1514 lb->resize(getNumSymbolVars() + 1); 1515 if (ub) 1516 ub->resize(getNumSymbolVars() + 1); 1517 // The lower bound is the ceildiv of the lb constraint over the coefficient 1518 // of the variable at 'pos'. We express the ceildiv equivalently as a floor 1519 // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + 1520 // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). 1521 *boundFloorDivisor = atIneq(minLbPosition, pos); 1522 assert(*boundFloorDivisor == -atIneq(minUbPosition, pos)); 1523 for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) { 1524 (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c); 1525 } 1526 if (ub) { 1527 for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) 1528 (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c); 1529 } 1530 // The lower bound leads to a ceildiv while the upper bound is a floordiv 1531 // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val + 1532 // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to 1533 // the constant term for the lower bound. 1534 (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1; 1535 } 1536 if (minLbPos) 1537 *minLbPos = minLbPosition; 1538 if (minUbPos) 1539 *minUbPos = minUbPosition; 1540 return minDiff; 1541 } 1542 1543 template <bool isLower> 1544 Optional<int64_t> 1545 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { 1546 assert(pos < getNumVars() && "invalid position"); 1547 // Project to 'pos'. 1548 projectOut(0, pos); 1549 projectOut(1, getNumVars() - 1); 1550 // Check if there's an equality equating the '0'^th variable to a constant. 1551 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); 1552 if (eqRowIdx != -1) 1553 // atEq(rowIdx, 0) is either -1 or 1. 1554 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); 1555 1556 // Check if the variable appears at all in any of the inequalities. 1557 unsigned r, e; 1558 for (r = 0, e = getNumInequalities(); r < e; r++) { 1559 if (atIneq(r, 0) != 0) 1560 break; 1561 } 1562 if (r == e) 1563 // If it doesn't, there isn't a bound on it. 1564 return None; 1565 1566 Optional<int64_t> minOrMaxConst = None; 1567 1568 // Take the max across all const lower bounds (or min across all constant 1569 // upper bounds). 1570 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1571 if (isLower) { 1572 if (atIneq(r, 0) <= 0) 1573 // Not a lower bound. 1574 continue; 1575 } else if (atIneq(r, 0) >= 0) { 1576 // Not an upper bound. 1577 continue; 1578 } 1579 unsigned c, f; 1580 for (c = 0, f = getNumCols() - 1; c < f; c++) 1581 if (c != 0 && atIneq(r, c) != 0) 1582 break; 1583 if (c < getNumCols() - 1) 1584 // Not a constant bound. 1585 continue; 1586 1587 int64_t boundConst = 1588 isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) 1589 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); 1590 if (isLower) { 1591 if (minOrMaxConst == None || boundConst > minOrMaxConst) 1592 minOrMaxConst = boundConst; 1593 } else { 1594 if (minOrMaxConst == None || boundConst < minOrMaxConst) 1595 minOrMaxConst = boundConst; 1596 } 1597 } 1598 return minOrMaxConst; 1599 } 1600 1601 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type, 1602 unsigned pos) const { 1603 if (type == BoundType::LB) 1604 return IntegerRelation(*this) 1605 .computeConstantLowerOrUpperBound</*isLower=*/true>(pos); 1606 if (type == BoundType::UB) 1607 return IntegerRelation(*this) 1608 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1609 1610 assert(type == BoundType::EQ && "expected EQ"); 1611 Optional<int64_t> lb = 1612 IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>( 1613 pos); 1614 Optional<int64_t> ub = 1615 IntegerRelation(*this) 1616 .computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 1617 return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None; 1618 } 1619 1620 // A simple (naive and conservative) check for hyper-rectangularity. 1621 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const { 1622 assert(pos < getNumCols() - 1); 1623 // Check for two non-zero coefficients in the range [pos, pos + sum). 1624 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1625 unsigned sum = 0; 1626 for (unsigned c = pos; c < pos + num; c++) { 1627 if (atIneq(r, c) != 0) 1628 sum++; 1629 } 1630 if (sum > 1) 1631 return false; 1632 } 1633 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1634 unsigned sum = 0; 1635 for (unsigned c = pos; c < pos + num; c++) { 1636 if (atEq(r, c) != 0) 1637 sum++; 1638 } 1639 if (sum > 1) 1640 return false; 1641 } 1642 return true; 1643 } 1644 1645 /// Removes duplicate constraints, trivially true constraints, and constraints 1646 /// that can be detected as redundant as a result of differing only in their 1647 /// constant term part. A constraint of the form <non-negative constant> >= 0 is 1648 /// considered trivially true. 1649 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to 1650 // remove duplicates in place. 1651 void IntegerRelation::removeTrivialRedundancy() { 1652 gcdTightenInequalities(); 1653 normalizeConstraintsByGCD(); 1654 1655 // A map used to detect redundancy stemming from constraints that only differ 1656 // in their constant term. The value stored is <row position, const term> 1657 // for a given row. 1658 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>> 1659 rowsWithoutConstTerm; 1660 // To unique rows. 1661 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet; 1662 1663 // Check if constraint is of the form <non-negative-constant> >= 0. 1664 auto isTriviallyValid = [&](unsigned r) -> bool { 1665 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { 1666 if (atIneq(r, c) != 0) 1667 return false; 1668 } 1669 return atIneq(r, getNumCols() - 1) >= 0; 1670 }; 1671 1672 // Detect and mark redundant constraints. 1673 SmallVector<bool, 256> redunIneq(getNumInequalities(), false); 1674 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1675 int64_t *rowStart = &inequalities(r, 0); 1676 auto row = ArrayRef<int64_t>(rowStart, getNumCols()); 1677 if (isTriviallyValid(r) || !rowSet.insert(row).second) { 1678 redunIneq[r] = true; 1679 continue; 1680 } 1681 1682 // Among constraints that only differ in the constant term part, mark 1683 // everything other than the one with the smallest constant term redundant. 1684 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the 1685 // former two are redundant). 1686 int64_t constTerm = atIneq(r, getNumCols() - 1); 1687 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1); 1688 const auto &ret = 1689 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); 1690 if (!ret.second) { 1691 // Check if the other constraint has a higher constant term. 1692 auto &val = ret.first->second; 1693 if (val.second > constTerm) { 1694 // The stored row is redundant. Mark it so, and update with this one. 1695 redunIneq[val.first] = true; 1696 val = {r, constTerm}; 1697 } else { 1698 // The one stored makes this one redundant. 1699 redunIneq[r] = true; 1700 } 1701 } 1702 } 1703 1704 // Scan to get rid of all rows marked redundant, in-place. 1705 unsigned pos = 0; 1706 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) 1707 if (!redunIneq[r]) 1708 inequalities.copyRow(r, pos++); 1709 1710 inequalities.resizeVertically(pos); 1711 1712 // TODO: consider doing this for equalities as well, but probably not worth 1713 // the savings. 1714 } 1715 1716 #undef DEBUG_TYPE 1717 #define DEBUG_TYPE "fm" 1718 1719 /// Eliminates variable at the specified position using Fourier-Motzkin 1720 /// variable elimination. This technique is exact for rational spaces but 1721 /// conservative (in "rare" cases) for integer spaces. The operation corresponds 1722 /// to a projection operation yielding the (convex) set of integer points 1723 /// contained in the rational shadow of the set. An emptiness test that relies 1724 /// on this method will guarantee emptiness, i.e., it disproves the existence of 1725 /// a solution if it says it's empty. 1726 /// If a non-null isResultIntegerExact is passed, it is set to true if the 1727 /// result is also integer exact. If it's set to false, the obtained solution 1728 /// *may* not be exact, i.e., it may contain integer points that do not have an 1729 /// integer pre-image in the original set. 1730 /// 1731 /// Eg: 1732 /// j >= 0, j <= i + 1 1733 /// i >= 0, i <= N + 1 1734 /// Eliminating i yields, 1735 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1 1736 /// 1737 /// If darkShadow = true, this method computes the dark shadow on elimination; 1738 /// the dark shadow is a convex integer subset of the exact integer shadow. A 1739 /// non-empty dark shadow proves the existence of an integer solution. The 1740 /// elimination in such a case could however be an under-approximation, and thus 1741 /// should not be used for scanning sets or used by itself for dependence 1742 /// checking. 1743 /// 1744 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. 1745 /// ^ 1746 /// | 1747 /// | * * * * o o 1748 /// i | * * o o o o 1749 /// | o * * * * * 1750 /// ---------------> 1751 /// j -> 1752 /// 1753 /// Eliminating i from this system (projecting on the j dimension): 1754 /// rational shadow / integer light shadow: 1 <= j <= 6 1755 /// dark shadow: 3 <= j <= 6 1756 /// exact integer shadow: j = 1 \union 3 <= j <= 6 1757 /// holes/splinters: j = 2 1758 /// 1759 /// darkShadow = false, isResultIntegerExact = nullptr are default values. 1760 // TODO: a slight modification to yield dark shadow version of FM (tightened), 1761 // which can prove the existence of a solution if there is one. 1762 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, 1763 bool *isResultIntegerExact) { 1764 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); 1765 LLVM_DEBUG(dump()); 1766 assert(pos < getNumVars() && "invalid position"); 1767 assert(hasConsistentState()); 1768 1769 // Check if this variable can be eliminated through a substitution. 1770 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1771 if (atEq(r, pos) != 0) { 1772 // Use Gaussian elimination here (since we have an equality). 1773 LogicalResult ret = gaussianEliminateVar(pos); 1774 (void)ret; 1775 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); 1776 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); 1777 LLVM_DEBUG(dump()); 1778 return; 1779 } 1780 } 1781 1782 // A fast linear time tightening. 1783 gcdTightenInequalities(); 1784 1785 // Check if the variable appears at all in any of the inequalities. 1786 if (isColZero(pos)) { 1787 // If it doesn't appear, just remove the column and return. 1788 // TODO: refactor removeColumns to use it from here. 1789 removeVar(pos); 1790 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1791 LLVM_DEBUG(dump()); 1792 return; 1793 } 1794 1795 // Positions of constraints that are lower bounds on the variable. 1796 SmallVector<unsigned, 4> lbIndices; 1797 // Positions of constraints that are lower bounds on the variable. 1798 SmallVector<unsigned, 4> ubIndices; 1799 // Positions of constraints that do not involve the variable. 1800 std::vector<unsigned> nbIndices; 1801 nbIndices.reserve(getNumInequalities()); 1802 1803 // Gather all lower bounds and upper bounds of the variable. Since the 1804 // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower 1805 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1806 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1807 if (atIneq(r, pos) == 0) { 1808 // Var does not appear in bound. 1809 nbIndices.push_back(r); 1810 } else if (atIneq(r, pos) >= 1) { 1811 // Lower bound. 1812 lbIndices.push_back(r); 1813 } else { 1814 // Upper bound. 1815 ubIndices.push_back(r); 1816 } 1817 } 1818 1819 PresburgerSpace newSpace = getSpace(); 1820 VarKind idKindRemove = newSpace.getVarKindAt(pos); 1821 unsigned relativePos = pos - newSpace.getVarKindOffset(idKindRemove); 1822 newSpace.removeVarRange(idKindRemove, relativePos, relativePos + 1); 1823 1824 /// Create the new system which has one variable less. 1825 IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(), 1826 getNumEqualities(), getNumCols() - 1, newSpace); 1827 1828 // This will be used to check if the elimination was integer exact. 1829 unsigned lcmProducts = 1; 1830 1831 // Let x be the variable we are eliminating. 1832 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note 1833 // that c_l, c_u >= 1) we have: 1834 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u 1835 // We thus generate a constraint: 1836 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. 1837 // Note if c_l = c_u = 1, all integer points captured by the resulting 1838 // constraint correspond to integer points in the original system (i.e., they 1839 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is 1840 // integer exact. 1841 for (auto ubPos : ubIndices) { 1842 for (auto lbPos : lbIndices) { 1843 SmallVector<int64_t, 4> ineq; 1844 ineq.reserve(newRel.getNumCols()); 1845 int64_t lbCoeff = atIneq(lbPos, pos); 1846 // Note that in the comments above, ubCoeff is the negation of the 1847 // coefficient in the canonical form as the view taken here is that of the 1848 // term being moved to the other size of '>='. 1849 int64_t ubCoeff = -atIneq(ubPos, pos); 1850 // TODO: refactor this loop to avoid all branches inside. 1851 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1852 if (l == pos) 1853 continue; 1854 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); 1855 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); 1856 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + 1857 atIneq(lbPos, l) * (lcm / lbCoeff)); 1858 lcmProducts *= lcm; 1859 } 1860 if (darkShadow) { 1861 // The dark shadow is a convex subset of the exact integer shadow. If 1862 // there is a point here, it proves the existence of a solution. 1863 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; 1864 } 1865 // TODO: we need to have a way to add inequalities in-place in 1866 // IntegerRelation instead of creating and copying over. 1867 newRel.addInequality(ineq); 1868 } 1869 } 1870 1871 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1) 1872 << "\n"); 1873 if (lcmProducts == 1 && isResultIntegerExact) 1874 *isResultIntegerExact = true; 1875 1876 // Copy over the constraints not involving this variable. 1877 for (auto nbPos : nbIndices) { 1878 SmallVector<int64_t, 4> ineq; 1879 ineq.reserve(getNumCols() - 1); 1880 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1881 if (l == pos) 1882 continue; 1883 ineq.push_back(atIneq(nbPos, l)); 1884 } 1885 newRel.addInequality(ineq); 1886 } 1887 1888 assert(newRel.getNumConstraints() == 1889 lbIndices.size() * ubIndices.size() + nbIndices.size()); 1890 1891 // Copy over the equalities. 1892 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1893 SmallVector<int64_t, 4> eq; 1894 eq.reserve(newRel.getNumCols()); 1895 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 1896 if (l == pos) 1897 continue; 1898 eq.push_back(atEq(r, l)); 1899 } 1900 newRel.addEquality(eq); 1901 } 1902 1903 // GCD tightening and normalization allows detection of more trivially 1904 // redundant constraints. 1905 newRel.gcdTightenInequalities(); 1906 newRel.normalizeConstraintsByGCD(); 1907 newRel.removeTrivialRedundancy(); 1908 clearAndCopyFrom(newRel); 1909 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 1910 LLVM_DEBUG(dump()); 1911 } 1912 1913 #undef DEBUG_TYPE 1914 #define DEBUG_TYPE "presburger" 1915 1916 void IntegerRelation::projectOut(unsigned pos, unsigned num) { 1917 if (num == 0) 1918 return; 1919 1920 // 'pos' can be at most getNumCols() - 2 if num > 0. 1921 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position"); 1922 assert(pos + num < getNumCols() && "invalid range"); 1923 1924 // Eliminate as many variables as possible using Gaussian elimination. 1925 unsigned currentPos = pos; 1926 unsigned numToEliminate = num; 1927 unsigned numGaussianEliminated = 0; 1928 1929 while (currentPos < getNumVars()) { 1930 unsigned curNumEliminated = 1931 gaussianEliminateVars(currentPos, currentPos + numToEliminate); 1932 ++currentPos; 1933 numToEliminate -= curNumEliminated + 1; 1934 numGaussianEliminated += curNumEliminated; 1935 } 1936 1937 // Eliminate the remaining using Fourier-Motzkin. 1938 for (unsigned i = 0; i < num - numGaussianEliminated; i++) { 1939 unsigned numToEliminate = num - numGaussianEliminated - i; 1940 fourierMotzkinEliminate( 1941 getBestVarToEliminate(*this, pos, pos + numToEliminate)); 1942 } 1943 1944 // Fast/trivial simplifications. 1945 gcdTightenInequalities(); 1946 // Normalize constraints after tightening since the latter impacts this, but 1947 // not the other way round. 1948 normalizeConstraintsByGCD(); 1949 } 1950 1951 namespace { 1952 1953 enum BoundCmpResult { Greater, Less, Equal, Unknown }; 1954 1955 /// Compares two affine bounds whose coefficients are provided in 'first' and 1956 /// 'second'. The last coefficient is the constant term. 1957 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 1958 assert(a.size() == b.size()); 1959 1960 // For the bounds to be comparable, their corresponding variable 1961 // coefficients should be equal; the constant terms are then compared to 1962 // determine less/greater/equal. 1963 1964 if (!std::equal(a.begin(), a.end() - 1, b.begin())) 1965 return Unknown; 1966 1967 if (a.back() == b.back()) 1968 return Equal; 1969 1970 return a.back() < b.back() ? Less : Greater; 1971 } 1972 } // namespace 1973 1974 // Returns constraints that are common to both A & B. 1975 static void getCommonConstraints(const IntegerRelation &a, 1976 const IntegerRelation &b, IntegerRelation &c) { 1977 c = IntegerRelation(a.getSpace()); 1978 // a naive O(n^2) check should be enough here given the input sizes. 1979 for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) { 1980 for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) { 1981 if (a.getInequality(r) == b.getInequality(s)) { 1982 c.addInequality(a.getInequality(r)); 1983 break; 1984 } 1985 } 1986 } 1987 for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) { 1988 for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) { 1989 if (a.getEquality(r) == b.getEquality(s)) { 1990 c.addEquality(a.getEquality(r)); 1991 break; 1992 } 1993 } 1994 } 1995 } 1996 1997 // Computes the bounding box with respect to 'other' by finding the min of the 1998 // lower bounds and the max of the upper bounds along each of the dimensions. 1999 LogicalResult 2000 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { 2001 assert(space.isEqual(otherCst.getSpace()) && "Spaces should match."); 2002 assert(getNumLocalVars() == 0 && "local ids not supported yet here"); 2003 2004 // Get the constraints common to both systems; these will be added as is to 2005 // the union. 2006 IntegerRelation commonCst(PresburgerSpace::getRelationSpace()); 2007 getCommonConstraints(*this, otherCst, commonCst); 2008 2009 std::vector<SmallVector<int64_t, 8>> boundingLbs; 2010 std::vector<SmallVector<int64_t, 8>> boundingUbs; 2011 boundingLbs.reserve(2 * getNumDimVars()); 2012 boundingUbs.reserve(2 * getNumDimVars()); 2013 2014 // To hold lower and upper bounds for each dimension. 2015 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; 2016 // To compute min of lower bounds and max of upper bounds for each dimension. 2017 SmallVector<int64_t, 4> minLb(getNumSymbolVars() + 1); 2018 SmallVector<int64_t, 4> maxUb(getNumSymbolVars() + 1); 2019 // To compute final new lower and upper bounds for the union. 2020 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); 2021 2022 int64_t lbFloorDivisor, otherLbFloorDivisor; 2023 for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) { 2024 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); 2025 if (!extent.hasValue()) 2026 // TODO: symbolic extents when necessary. 2027 // TODO: handle union if a dimension is unbounded. 2028 return failure(); 2029 2030 auto otherExtent = otherCst.getConstantBoundOnDimSize( 2031 d, &otherLb, &otherLbFloorDivisor, &otherUb); 2032 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor) 2033 // TODO: symbolic extents when necessary. 2034 return failure(); 2035 2036 assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); 2037 2038 auto res = compareBounds(lb, otherLb); 2039 // Identify min. 2040 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { 2041 minLb = lb; 2042 // Since the divisor is for a floordiv, we need to convert to ceildiv, 2043 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=> 2044 // div * i >= expr - div + 1. 2045 minLb.back() -= lbFloorDivisor - 1; 2046 } else if (res == BoundCmpResult::Greater) { 2047 minLb = otherLb; 2048 minLb.back() -= otherLbFloorDivisor - 1; 2049 } else { 2050 // Uncomparable - check for constant lower/upper bounds. 2051 auto constLb = getConstantBound(BoundType::LB, d); 2052 auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d); 2053 if (!constLb.hasValue() || !constOtherLb.hasValue()) 2054 return failure(); 2055 std::fill(minLb.begin(), minLb.end(), 0); 2056 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); 2057 } 2058 2059 // Do the same for ub's but max of upper bounds. Identify max. 2060 auto uRes = compareBounds(ub, otherUb); 2061 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { 2062 maxUb = ub; 2063 } else if (uRes == BoundCmpResult::Less) { 2064 maxUb = otherUb; 2065 } else { 2066 // Uncomparable - check for constant lower/upper bounds. 2067 auto constUb = getConstantBound(BoundType::UB, d); 2068 auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d); 2069 if (!constUb.hasValue() || !constOtherUb.hasValue()) 2070 return failure(); 2071 std::fill(maxUb.begin(), maxUb.end(), 0); 2072 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); 2073 } 2074 2075 std::fill(newLb.begin(), newLb.end(), 0); 2076 std::fill(newUb.begin(), newUb.end(), 0); 2077 2078 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, 2079 // and so it's the divisor for newLb and newUb as well. 2080 newLb[d] = lbFloorDivisor; 2081 newUb[d] = -lbFloorDivisor; 2082 // Copy over the symbolic part + constant term. 2083 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars()); 2084 std::transform(newLb.begin() + getNumDimVars(), newLb.end(), 2085 newLb.begin() + getNumDimVars(), std::negate<int64_t>()); 2086 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars()); 2087 2088 boundingLbs.push_back(newLb); 2089 boundingUbs.push_back(newUb); 2090 } 2091 2092 // Clear all constraints and add the lower/upper bounds for the bounding box. 2093 clearConstraints(); 2094 for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) { 2095 addInequality(boundingLbs[d]); 2096 addInequality(boundingUbs[d]); 2097 } 2098 2099 // Add the constraints that were common to both systems. 2100 append(commonCst); 2101 removeTrivialRedundancy(); 2102 2103 // TODO: copy over pure symbolic constraints from this and 'other' over to the 2104 // union (since the above are just the union along dimensions); we shouldn't 2105 // be discarding any other constraints on the symbols. 2106 2107 return success(); 2108 } 2109 2110 bool IntegerRelation::isColZero(unsigned pos) const { 2111 unsigned rowPos; 2112 return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) && 2113 !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos); 2114 } 2115 2116 /// Find positions of inequalities and equalities that do not have a coefficient 2117 /// for [pos, pos + num) variables. 2118 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos, 2119 unsigned num, 2120 SmallVectorImpl<unsigned> &nbIneqIndices, 2121 SmallVectorImpl<unsigned> &nbEqIndices) { 2122 assert(pos < cst.getNumVars() && "invalid start position"); 2123 assert(pos + num <= cst.getNumVars() && "invalid limit"); 2124 2125 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 2126 // The bounds are to be independent of [offset, offset + num) columns. 2127 unsigned c; 2128 for (c = pos; c < pos + num; ++c) { 2129 if (cst.atIneq(r, c) != 0) 2130 break; 2131 } 2132 if (c == pos + num) 2133 nbIneqIndices.push_back(r); 2134 } 2135 2136 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 2137 // The bounds are to be independent of [offset, offset + num) columns. 2138 unsigned c; 2139 for (c = pos; c < pos + num; ++c) { 2140 if (cst.atEq(r, c) != 0) 2141 break; 2142 } 2143 if (c == pos + num) 2144 nbEqIndices.push_back(r); 2145 } 2146 } 2147 2148 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) { 2149 assert(pos + num <= getNumVars() && "invalid range"); 2150 2151 // Remove constraints that are independent of these variables. 2152 SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices; 2153 getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices); 2154 2155 // Iterate in reverse so that indices don't have to be updated. 2156 // TODO: This method can be made more efficient (because removal of each 2157 // inequality leads to much shifting/copying in the underlying buffer). 2158 for (auto nbIndex : llvm::reverse(nbIneqIndices)) 2159 removeInequality(nbIndex); 2160 for (auto nbIndex : llvm::reverse(nbEqIndices)) 2161 removeEquality(nbIndex); 2162 } 2163 2164 IntegerPolyhedron IntegerRelation::getDomainSet() const { 2165 IntegerRelation copyRel = *this; 2166 2167 // Convert Range variables to Local variables. 2168 copyRel.convertVarKind(VarKind::Range, 0, getNumVarKind(VarKind::Range), 2169 VarKind::Local); 2170 2171 // Convert Domain variables to SetDim(Range) variables. 2172 copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain), 2173 VarKind::SetDim); 2174 2175 return IntegerPolyhedron(std::move(copyRel)); 2176 } 2177 2178 IntegerPolyhedron IntegerRelation::getRangeSet() const { 2179 IntegerRelation copyRel = *this; 2180 2181 // Convert Domain variables to Local variables. 2182 copyRel.convertVarKind(VarKind::Domain, 0, getNumVarKind(VarKind::Domain), 2183 VarKind::Local); 2184 2185 // We do not need to do anything to Range variables since they are already in 2186 // SetDim position. 2187 2188 return IntegerPolyhedron(std::move(copyRel)); 2189 } 2190 2191 void IntegerRelation::intersectDomain(const IntegerPolyhedron &poly) { 2192 assert(getDomainSet().getSpace().isCompatible(poly.getSpace()) && 2193 "Domain set is not compatible with poly"); 2194 2195 // Treating the poly as a relation, convert it from `0 -> R` to `R -> 0`. 2196 IntegerRelation rel = poly; 2197 rel.inverse(); 2198 2199 // Append dummy range variables to make the spaces compatible. 2200 rel.appendVar(VarKind::Range, getNumRangeVars()); 2201 2202 // Intersect in place. 2203 mergeLocalVars(rel); 2204 append(rel); 2205 } 2206 2207 void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) { 2208 assert(getRangeSet().getSpace().isCompatible(poly.getSpace()) && 2209 "Range set is not compatible with poly"); 2210 2211 IntegerRelation rel = poly; 2212 2213 // Append dummy domain variables to make the spaces compatible. 2214 rel.appendVar(VarKind::Domain, getNumDomainVars()); 2215 2216 mergeLocalVars(rel); 2217 append(rel); 2218 } 2219 2220 void IntegerRelation::inverse() { 2221 unsigned numRangeVars = getNumVarKind(VarKind::Range); 2222 convertVarKind(VarKind::Domain, 0, getVarKindEnd(VarKind::Domain), 2223 VarKind::Range); 2224 convertVarKind(VarKind::Range, 0, numRangeVars, VarKind::Domain); 2225 } 2226 2227 void IntegerRelation::compose(const IntegerRelation &rel) { 2228 assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) && 2229 "Range of `this` should be compatible with Domain of `rel`"); 2230 2231 IntegerRelation copyRel = rel; 2232 2233 // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C. 2234 // We convert R1 to A -> (B X C), and R2 to B X C then intersect the range of 2235 // R1 with R2. After this, we get R1: A -> C, by projecting out B. 2236 // TODO: Using nested spaces here would help, since we could directly 2237 // intersect the range with another relation. 2238 unsigned numBVars = getNumRangeVars(); 2239 2240 // Convert R1 from A -> B to A -> (B X C). 2241 appendVar(VarKind::Range, copyRel.getNumRangeVars()); 2242 2243 // Convert R2 to B X C. 2244 copyRel.convertVarKind(VarKind::Domain, 0, numBVars, VarKind::Range, 0); 2245 2246 // Intersect R2 to range of R1. 2247 intersectRange(IntegerPolyhedron(copyRel)); 2248 2249 // Project out B in R1. 2250 convertVarKind(VarKind::Range, 0, numBVars, VarKind::Local); 2251 } 2252 2253 void IntegerRelation::applyDomain(const IntegerRelation &rel) { 2254 inverse(); 2255 compose(rel); 2256 inverse(); 2257 } 2258 2259 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); } 2260 2261 void IntegerRelation::printSpace(raw_ostream &os) const { 2262 space.print(os); 2263 os << getNumConstraints() << " constraints\n"; 2264 } 2265 2266 void IntegerRelation::print(raw_ostream &os) const { 2267 assert(hasConsistentState()); 2268 printSpace(os); 2269 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 2270 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2271 os << atEq(i, j) << " "; 2272 } 2273 os << "= 0\n"; 2274 } 2275 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 2276 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2277 os << atIneq(i, j) << " "; 2278 } 2279 os << ">= 0\n"; 2280 } 2281 os << '\n'; 2282 } 2283 2284 void IntegerRelation::dump() const { print(llvm::errs()); } 2285 2286 unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos, 2287 unsigned num) { 2288 assert((kind != VarKind::Domain || num == 0) && 2289 "Domain has to be zero in a set"); 2290 return IntegerRelation::insertVar(kind, pos, num); 2291 } 2292 IntegerPolyhedron 2293 IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const { 2294 return IntegerPolyhedron(IntegerRelation::intersect(other)); 2295 } 2296 2297 PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const { 2298 return PresburgerSet(IntegerRelation::subtract(other)); 2299 } 2300