1 //===- Utils.cpp - General utilities for Presburger library ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility functions required by the Presburger Library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/Presburger/Utils.h" 14 #include "mlir/Analysis/Presburger/IntegerRelation.h" 15 #include "mlir/Support/LogicalResult.h" 16 #include "mlir/Support/MathExtras.h" 17 18 using namespace mlir; 19 using namespace presburger; 20 21 /// Normalize a division's `dividend` and the `divisor` by their GCD. For 22 /// example: if the dividend and divisor are [2,0,4] and 4 respectively, 23 /// they get normalized to [1,0,2] and 2. 24 static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend, 25 unsigned &divisor) { 26 if (divisor == 0 || dividend.empty()) 27 return; 28 // We take the absolute value of dividend's coefficients to make sure that 29 // `gcd` is positive. 30 int64_t gcd = 31 llvm::greatestCommonDivisor(std::abs(dividend.front()), int64_t(divisor)); 32 33 // The reason for ignoring the constant term is as follows. 34 // For a division: 35 // floor((a + m.f(x))/(m.d)) 36 // It can be replaced by: 37 // floor((floor(a/m) + f(x))/d) 38 // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not 39 // influence the result of the floor division and thus, can be ignored. 40 for (size_t i = 1, m = dividend.size() - 1; i < m; i++) { 41 gcd = llvm::greatestCommonDivisor(std::abs(dividend[i]), gcd); 42 if (gcd == 1) 43 return; 44 } 45 46 // Normalize the dividend and the denominator. 47 std::transform(dividend.begin(), dividend.end(), dividend.begin(), 48 [gcd](int64_t &n) { return floorDiv(n, gcd); }); 49 divisor /= gcd; 50 } 51 52 /// Check if the pos^th variable can be represented as a division using upper 53 /// bound inequality at position `ubIneq` and lower bound inequality at position 54 /// `lbIneq`. 55 /// 56 /// Let `var` be the pos^th variable, then `var` is equivalent to 57 /// `expr floordiv divisor` if there are constraints of the form: 58 /// 0 <= expr - divisor * var <= divisor - 1 59 /// Rearranging, we have: 60 /// divisor * var - expr + (divisor - 1) >= 0 <-- Lower bound for 'var' 61 /// -divisor * var + expr >= 0 <-- Upper bound for 'var' 62 /// 63 /// For example: 64 /// 32*k >= 16*i + j - 31 <-- Lower bound for 'k' 65 /// 32*k <= 16*i + j <-- Upper bound for 'k' 66 /// expr = 16*i + j, divisor = 32 67 /// k = ( 16*i + j ) floordiv 32 68 /// 69 /// 4q >= i + j - 2 <-- Lower bound for 'q' 70 /// 4q <= i + j + 1 <-- Upper bound for 'q' 71 /// expr = i + j + 1, divisor = 4 72 /// q = (i + j + 1) floordiv 4 73 // 74 /// This function also supports detecting divisions from bounds that are 75 /// strictly tighter than the division bounds described above, since tighter 76 /// bounds imply the division bounds. For example: 77 /// 4q - i - j + 2 >= 0 <-- Lower bound for 'q' 78 /// -4q + i + j >= 0 <-- Tight upper bound for 'q' 79 /// 80 /// To extract floor divisions with tighter bounds, we assume that that the 81 /// constraints are of the form: 82 /// c <= expr - divisior * var <= divisor - 1, where 0 <= c <= divisor - 1 83 /// Rearranging, we have: 84 /// divisor * var - expr + (divisor - 1) >= 0 <-- Lower bound for 'var' 85 /// -divisor * var + expr - c >= 0 <-- Upper bound for 'var' 86 /// 87 /// If successful, `expr` is set to dividend of the division and `divisor` is 88 /// set to the denominator of the division. The final division expression is 89 /// normalized by GCD. 90 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, 91 unsigned ubIneq, unsigned lbIneq, 92 MutableArrayRef<int64_t> expr, 93 unsigned &divisor) { 94 95 assert(pos <= cst.getNumVars() && "Invalid variable position"); 96 assert(ubIneq <= cst.getNumInequalities() && 97 "Invalid upper bound inequality position"); 98 assert(lbIneq <= cst.getNumInequalities() && 99 "Invalid upper bound inequality position"); 100 assert(expr.size() == cst.getNumCols() && "Invalid expression size"); 101 102 // Extract divisor from the lower bound. 103 divisor = cst.atIneq(lbIneq, pos); 104 105 // First, check if the constraints are opposite of each other except the 106 // constant term. 107 unsigned i = 0, e = 0; 108 for (i = 0, e = cst.getNumVars(); i < e; ++i) 109 if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i)) 110 break; 111 112 if (i < e) 113 return failure(); 114 115 // Then, check if the constant term is of the proper form. 116 // Due to the form of the upper/lower bound inequalities, the sum of their 117 // constants is `divisor - 1 - c`. From this, we can extract c: 118 int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + 119 cst.atIneq(ubIneq, cst.getNumCols() - 1); 120 int64_t c = divisor - 1 - constantSum; 121 122 // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also 123 // implictly checks that `divisor` is positive. 124 if (!(0 <= c && c <= divisor - 1)) // NOLINT 125 return failure(); 126 127 // The inequality pair can be used to extract the division. 128 // Set `expr` to the dividend of the division except the constant term, which 129 // is set below. 130 for (i = 0, e = cst.getNumVars(); i < e; ++i) 131 if (i != pos) 132 expr[i] = cst.atIneq(ubIneq, i); 133 134 // From the upper bound inequality's form, its constant term is equal to the 135 // constant term of `expr`, minus `c`. From this, 136 // constant term of `expr` = constant term of upper bound + `c`. 137 expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c; 138 normalizeDivisionByGCD(expr, divisor); 139 140 return success(); 141 } 142 143 /// Check if the pos^th variable can be represented as a division using 144 /// equality at position `eqInd`. 145 /// 146 /// For example: 147 /// 32*k == 16*i + j - 31 <-- `eqInd` for 'k' 148 /// expr = 16*i + j - 31, divisor = 32 149 /// k = (16*i + j - 31) floordiv 32 150 /// 151 /// If successful, `expr` is set to dividend of the division and `divisor` is 152 /// set to the denominator of the division. The final division expression is 153 /// normalized by GCD. 154 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, 155 unsigned eqInd, MutableArrayRef<int64_t> expr, 156 unsigned &divisor) { 157 158 assert(pos <= cst.getNumVars() && "Invalid variable position"); 159 assert(eqInd <= cst.getNumEqualities() && "Invalid equality position"); 160 assert(expr.size() == cst.getNumCols() && "Invalid expression size"); 161 162 // Extract divisor, the divisor can be negative and hence its sign information 163 // is stored in `signDiv` to reverse the sign of dividend's coefficients. 164 // Equality must involve the pos-th variable and hence `tempDiv` != 0. 165 int64_t tempDiv = cst.atEq(eqInd, pos); 166 if (tempDiv == 0) 167 return failure(); 168 int64_t signDiv = tempDiv < 0 ? -1 : 1; 169 170 // The divisor is always a positive integer. 171 divisor = tempDiv * signDiv; 172 173 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) 174 if (i != pos) 175 expr[i] = -signDiv * cst.atEq(eqInd, i); 176 177 expr.back() = -signDiv * cst.atEq(eqInd, cst.getNumCols() - 1); 178 normalizeDivisionByGCD(expr, divisor); 179 180 return success(); 181 } 182 183 // Returns `false` if the constraints depends on a variable for which an 184 // explicit representation has not been found yet, otherwise returns `true`. 185 static bool checkExplicitRepresentation(const IntegerRelation &cst, 186 ArrayRef<bool> foundRepr, 187 ArrayRef<int64_t> dividend, 188 unsigned pos) { 189 // Exit to avoid circular dependencies between divisions. 190 for (unsigned c = 0, e = cst.getNumVars(); c < e; ++c) { 191 if (c == pos) 192 continue; 193 194 if (!foundRepr[c] && dividend[c] != 0) { 195 // Expression can't be constructed as it depends on a yet unknown 196 // variable. 197 // 198 // TODO: Visit/compute the variables in an order so that this doesn't 199 // happen. More complex but much more efficient. 200 return false; 201 } 202 } 203 204 return true; 205 } 206 207 /// Check if the pos^th variable can be expressed as a floordiv of an affine 208 /// function of other variables (where the divisor is a positive constant). 209 /// `foundRepr` contains a boolean for each variable indicating if the 210 /// explicit representation for that variable has already been computed. 211 /// Returns the `MaybeLocalRepr` struct which contains the indices of the 212 /// constraints that can be expressed as a floordiv of an affine function. If 213 /// the representation could be computed, `dividend` and `denominator` are set. 214 /// If the representation could not be computed, the kind attribute in 215 /// `MaybeLocalRepr` is set to None. 216 MaybeLocalRepr presburger::computeSingleVarRepr( 217 const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos, 218 MutableArrayRef<int64_t> dividend, unsigned &divisor) { 219 assert(pos < cst.getNumVars() && "invalid position"); 220 assert(foundRepr.size() == cst.getNumVars() && 221 "Size of foundRepr does not match total number of variables"); 222 assert(dividend.size() == cst.getNumCols() && "Invalid dividend size"); 223 224 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; 225 cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, &eqIndices); 226 MaybeLocalRepr repr{}; 227 228 for (unsigned ubPos : ubIndices) { 229 for (unsigned lbPos : lbIndices) { 230 // Attempt to get divison representation from ubPos, lbPos. 231 if (failed(getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor))) 232 continue; 233 234 if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos)) 235 continue; 236 237 repr.kind = ReprKind::Inequality; 238 repr.repr.inequalityPair = {ubPos, lbPos}; 239 return repr; 240 } 241 } 242 for (unsigned eqPos : eqIndices) { 243 // Attempt to get divison representation from eqPos. 244 if (failed(getDivRepr(cst, pos, eqPos, dividend, divisor))) 245 continue; 246 247 if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos)) 248 continue; 249 250 repr.kind = ReprKind::Equality; 251 repr.repr.equalityIdx = eqPos; 252 return repr; 253 } 254 return repr; 255 } 256 257 llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len, 258 unsigned setOffset, 259 unsigned numSet) { 260 llvm::SmallBitVector vec(len, false); 261 vec.set(setOffset, setOffset + numSet); 262 return vec; 263 } 264 265 void presburger::mergeLocalVars( 266 IntegerRelation &relA, IntegerRelation &relB, 267 llvm::function_ref<bool(unsigned i, unsigned j)> merge) { 268 assert(relA.getSpace().isCompatible(relB.getSpace()) && 269 "Spaces should be compatible."); 270 271 // Merge local vars of relA and relB without using division information, 272 // i.e. append local vars of `relB` to `relA` and insert local vars of `relA` 273 // to `relB` at start of its local vars. 274 unsigned initLocals = relA.getNumLocalVars(); 275 relA.insertVar(VarKind::Local, relA.getNumLocalVars(), 276 relB.getNumLocalVars()); 277 relB.insertVar(VarKind::Local, 0, initLocals); 278 279 // Get division representations from each rel. 280 DivisionRepr divsA = relA.getLocalReprs(); 281 DivisionRepr divsB = relB.getLocalReprs(); 282 283 for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) { 284 divsA.setDividend(i, divsB.getDividend(i)); 285 divsA.getDenom(i) = divsB.getDenom(i); 286 } 287 288 // Remove duplicate divisions from divsA. The removing duplicate divisions 289 // call, calls `merge` to effectively merge divisions in relA and relB. 290 divsA.removeDuplicateDivs(merge); 291 } 292 293 SmallVector<int64_t, 8> presburger::getDivUpperBound(ArrayRef<int64_t> dividend, 294 int64_t divisor, 295 unsigned localVarIdx) { 296 assert(dividend[localVarIdx] == 0 && 297 "Local to be set to division must have zero coeff!"); 298 SmallVector<int64_t, 8> ineq(dividend.begin(), dividend.end()); 299 ineq[localVarIdx] = -divisor; 300 return ineq; 301 } 302 303 SmallVector<int64_t, 8> presburger::getDivLowerBound(ArrayRef<int64_t> dividend, 304 int64_t divisor, 305 unsigned localVarIdx) { 306 assert(dividend[localVarIdx] == 0 && 307 "Local to be set to division must have zero coeff!"); 308 SmallVector<int64_t, 8> ineq(dividend.size()); 309 std::transform(dividend.begin(), dividend.end(), ineq.begin(), 310 std::negate<int64_t>()); 311 ineq[localVarIdx] = divisor; 312 ineq.back() += divisor - 1; 313 return ineq; 314 } 315 316 int64_t presburger::gcdRange(ArrayRef<int64_t> range) { 317 int64_t gcd = 0; 318 for (int64_t elem : range) { 319 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(elem)); 320 if (gcd == 1) 321 return gcd; 322 } 323 return gcd; 324 } 325 326 int64_t presburger::normalizeRange(MutableArrayRef<int64_t> range) { 327 int64_t gcd = gcdRange(range); 328 if (gcd == 0 || gcd == 1) 329 return gcd; 330 for (int64_t &elem : range) 331 elem /= gcd; 332 return gcd; 333 } 334 335 void presburger::normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom) { 336 assert(denom > 0 && "denom must be positive!"); 337 int64_t gcd = llvm::greatestCommonDivisor(gcdRange(num), denom); 338 for (int64_t &coeff : num) 339 coeff /= gcd; 340 denom /= gcd; 341 } 342 343 SmallVector<int64_t, 8> presburger::getNegatedCoeffs(ArrayRef<int64_t> coeffs) { 344 SmallVector<int64_t, 8> negatedCoeffs; 345 negatedCoeffs.reserve(coeffs.size()); 346 for (int64_t coeff : coeffs) 347 negatedCoeffs.emplace_back(-coeff); 348 return negatedCoeffs; 349 } 350 351 SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) { 352 SmallVector<int64_t, 8> coeffs; 353 coeffs.reserve(ineq.size()); 354 for (int64_t coeff : ineq) 355 coeffs.emplace_back(-coeff); 356 --coeffs.back(); 357 return coeffs; 358 } 359 360 void DivisionRepr::removeDuplicateDivs( 361 llvm::function_ref<bool(unsigned i, unsigned j)> merge) { 362 363 // Find and merge duplicate divisions. 364 // TODO: Add division normalization to support divisions that differ by 365 // a constant. 366 // TODO: Add division ordering such that a division representation for local 367 // variable at position `i` only depends on local variables at position < 368 // `i`. This would make sure that all divisions depending on other local 369 // variables that can be merged, are merged. 370 for (unsigned i = 0; i < getNumDivs(); ++i) { 371 // Check if a division representation exists for the `i^th` local var. 372 if (denoms[i] == 0) 373 continue; 374 // Check if a division exists which is a duplicate of the division at `i`. 375 for (unsigned j = i + 1; j < getNumDivs(); ++j) { 376 // Check if a division representation exists for the `j^th` local var. 377 if (denoms[j] == 0) 378 continue; 379 // Check if the denominators match. 380 if (denoms[i] != denoms[j]) 381 continue; 382 // Check if the representations are equal. 383 if (dividends.getRow(i) != dividends.getRow(j)) 384 continue; 385 386 // Merge divisions at position `j` into division at position `i`. If 387 // merge fails, do not merge these divs. 388 bool mergeResult = merge(i, j); 389 if (!mergeResult) 390 continue; 391 392 // Update division information to reflect merging. 393 unsigned divOffset = getDivOffset(); 394 dividends.addToColumn(divOffset + j, divOffset + i, /*scale=*/1); 395 dividends.removeColumn(divOffset + j); 396 dividends.removeRow(j); 397 denoms.erase(denoms.begin() + j); 398 399 // Since `j` can never be zero, we do not need to worry about overflows. 400 --j; 401 } 402 } 403 } 404 405 void DivisionRepr::print(raw_ostream &os) const { 406 os << "Dividends:\n"; 407 dividends.print(os); 408 os << "Denominators\n"; 409 for (unsigned i = 0, e = denoms.size(); i < e; ++i) 410 os << denoms[i] << " "; 411 os << "\n"; 412 } 413 414 void DivisionRepr::dump() const { print(llvm::errs()); } 415 416 SmallVector<MPInt, 8> presburger::getMPIntVec(ArrayRef<int64_t> range) { 417 SmallVector<MPInt, 8> result(range.size()); 418 std::transform(range.begin(), range.end(), result.begin(), mpintFromInt64); 419 return result; 420 } 421 422 SmallVector<int64_t, 8> presburger::getInt64Vec(ArrayRef<MPInt> range) { 423 SmallVector<int64_t, 8> result(range.size()); 424 std::transform(range.begin(), range.end(), result.begin(), int64FromMPInt); 425 return result; 426 } 427