1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Structures for affine/polyhedral analysis of affine dialect ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 14 #include "mlir/Analysis/Presburger/LinearTransform.h" 15 #include "mlir/Analysis/Presburger/Simplex.h" 16 #include "mlir/Analysis/Presburger/Utils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 20 #include "mlir/IR/AffineExprVisitor.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Support/MathExtras.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #define DEBUG_TYPE "affine-structures" 31 32 using namespace mlir; 33 using namespace presburger; 34 35 namespace { 36 37 // See comments for SimpleAffineExprFlattener. 38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording 39 // constraint information associated with mod's, floordiv's, and ceildiv's 40 // in FlatAffineValueConstraints 'localVarCst'. 41 struct AffineExprFlattener : public SimpleAffineExprFlattener { 42 public: 43 // Constraints connecting newly introduced local variables (for mod's and 44 // div's) to existing (dimensional and symbolic) ones. These are always 45 // inequalities. 
46 IntegerPolyhedron localVarCst; 47 48 AffineExprFlattener(unsigned nDims, unsigned nSymbols) 49 : SimpleAffineExprFlattener(nDims, nSymbols), 50 localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {} 51 52 private: 53 // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr). 54 // The local variable added is always a floordiv of a pure add/mul affine 55 // function of other variables, coefficients of which are specified in 56 // `dividend' and with respect to the positive constant `divisor'. localExpr 57 // is the simplified tree expression (AffineExpr) corresponding to the 58 // quantifier. 59 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 60 AffineExpr localExpr) override { 61 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 62 // Update localVarCst. 63 localVarCst.addLocalFloorDiv(dividend, divisor); 64 } 65 }; 66 67 } // namespace 68 69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be 70 // flattened (i.e., semi-affine expressions not handled yet). 71 static LogicalResult 72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, 73 unsigned numSymbols, 74 std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 75 FlatAffineValueConstraints *localVarCst) { 76 if (exprs.empty()) { 77 localVarCst->reset(numDims, numSymbols); 78 return success(); 79 } 80 81 AffineExprFlattener flattener(numDims, numSymbols); 82 // Use the same flattener to simplify each expression successively. This way 83 // local variables / expressions are shared. 
84 for (auto expr : exprs) { 85 if (!expr.isPureAffine()) 86 return failure(); 87 88 flattener.walkPostOrder(expr); 89 } 90 91 assert(flattener.operandExprStack.size() == exprs.size()); 92 flattenedExprs->clear(); 93 flattenedExprs->assign(flattener.operandExprStack.begin(), 94 flattener.operandExprStack.end()); 95 96 if (localVarCst) 97 localVarCst->clearAndCopyFrom(flattener.localVarCst); 98 99 return success(); 100 } 101 102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 103 // be flattened (semi-affine expressions not handled yet). 104 LogicalResult 105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, 106 unsigned numSymbols, 107 SmallVectorImpl<int64_t> *flattenedExpr, 108 FlatAffineValueConstraints *localVarCst) { 109 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 110 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, 111 &flattenedExprs, localVarCst); 112 *flattenedExpr = flattenedExprs[0]; 113 return ret; 114 } 115 116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 117 /// flattened (i.e., semi-affine expressions not handled yet). 
118 LogicalResult mlir::getFlattenedAffineExprs( 119 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 120 FlatAffineValueConstraints *localVarCst) { 121 if (map.getNumResults() == 0) { 122 localVarCst->reset(map.getNumDims(), map.getNumSymbols()); 123 return success(); 124 } 125 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), 126 map.getNumSymbols(), flattenedExprs, 127 localVarCst); 128 } 129 130 LogicalResult mlir::getFlattenedAffineExprs( 131 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 132 FlatAffineValueConstraints *localVarCst) { 133 if (set.getNumConstraints() == 0) { 134 localVarCst->reset(set.getNumDims(), set.getNumSymbols()); 135 return success(); 136 } 137 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 138 set.getNumSymbols(), flattenedExprs, 139 localVarCst); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // FlatAffineConstraints / FlatAffineValueConstraints. 144 //===----------------------------------------------------------------------===// 145 146 std::unique_ptr<FlatAffineValueConstraints> 147 FlatAffineValueConstraints::clone() const { 148 return std::make_unique<FlatAffineValueConstraints>(*this); 149 } 150 151 // Construct from an IntegerSet. 152 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set) 153 : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(), 154 set.getNumDims() + set.getNumSymbols() + 1, 155 PresburgerSpace::getSetSpace(set.getNumDims(), 156 set.getNumSymbols(), 157 /*numLocals=*/0)) { 158 159 // Resize values. 160 values.resize(getNumDimAndSymbolVars(), None); 161 162 // Flatten expressions and add them to the constraint system. 
163 std::vector<SmallVector<int64_t, 8>> flatExprs; 164 FlatAffineValueConstraints localVarCst; 165 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 166 assert(false && "flattening unimplemented for semi-affine integer sets"); 167 return; 168 } 169 assert(flatExprs.size() == set.getNumConstraints()); 170 insertVar(VarKind::Local, getNumVarKind(VarKind::Local), 171 /*num=*/localVarCst.getNumLocalVars()); 172 173 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 174 const auto &flatExpr = flatExprs[i]; 175 assert(flatExpr.size() == getNumCols()); 176 if (set.getEqFlags()[i]) { 177 addEquality(flatExpr); 178 } else { 179 addInequality(flatExpr); 180 } 181 } 182 // Add the other constraints involving local vars from flattening. 183 append(localVarCst); 184 } 185 186 // Construct a hyperrectangular constraint set from ValueRanges that represent 187 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are 188 // expected to match one to one. The order of variables and constraints is: 189 // 190 // ivs | lbs | ubs | eq/ineq 191 // ----+-----+-----+--------- 192 // 1 -1 0 >= 0 193 // ----+-----+-----+--------- 194 // -1 0 1 >= 0 195 // 196 // All dimensions as set as VarKind::SetDim. 
197 FlatAffineValueConstraints 198 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs, 199 ValueRange ubs) { 200 FlatAffineValueConstraints res; 201 unsigned nIvs = ivs.size(); 202 assert(nIvs == lbs.size() && "expected as many lower bounds as ivs"); 203 assert(nIvs == ubs.size() && "expected as many upper bounds as ivs"); 204 205 if (nIvs == 0) 206 return res; 207 208 res.appendDimVar(ivs); 209 unsigned lbsStart = res.appendDimVar(lbs); 210 unsigned ubsStart = res.appendDimVar(ubs); 211 212 MLIRContext *ctx = ivs.front().getContext(); 213 for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) { 214 // iv - lb >= 0 215 AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 216 getAffineDimExpr(lbsStart + ivIdx, ctx)); 217 if (failed(res.addBound(BoundType::LB, ivIdx, lb))) 218 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 219 // -iv + ub >= 0 220 AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 221 getAffineDimExpr(ubsStart + ivIdx, ctx)); 222 if (failed(res.addBound(BoundType::UB, ivIdx, ub))) 223 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 224 } 225 return res; 226 } 227 228 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities, 229 unsigned numReservedEqualities, 230 unsigned newNumReservedCols, 231 unsigned newNumDims, 232 unsigned newNumSymbols, 233 unsigned newNumLocals) { 234 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 235 "minimum 1 column"); 236 *this = FlatAffineValueConstraints(numReservedInequalities, 237 numReservedEqualities, newNumReservedCols, 238 newNumDims, newNumSymbols, newNumLocals); 239 } 240 241 void FlatAffineValueConstraints::reset(unsigned newNumDims, 242 unsigned newNumSymbols, 243 unsigned newNumLocals) { 244 reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, 245 /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1, 246 newNumDims, 
newNumSymbols, newNumLocals); 247 } 248 249 void FlatAffineValueConstraints::reset( 250 unsigned numReservedInequalities, unsigned numReservedEqualities, 251 unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, 252 unsigned newNumLocals, ArrayRef<Value> valArgs) { 253 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 254 "minimum 1 column"); 255 SmallVector<Optional<Value>, 8> newVals; 256 if (!valArgs.empty()) 257 newVals.assign(valArgs.begin(), valArgs.end()); 258 259 *this = FlatAffineValueConstraints( 260 numReservedInequalities, numReservedEqualities, newNumReservedCols, 261 newNumDims, newNumSymbols, newNumLocals, newVals); 262 } 263 264 void FlatAffineValueConstraints::reset(unsigned newNumDims, 265 unsigned newNumSymbols, 266 unsigned newNumLocals, 267 ArrayRef<Value> valArgs) { 268 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, 269 newNumSymbols, newNumLocals, valArgs); 270 } 271 272 unsigned FlatAffineValueConstraints::appendDimVar(ValueRange vals) { 273 unsigned pos = getNumDimVars(); 274 insertVar(VarKind::SetDim, pos, vals); 275 return pos; 276 } 277 278 unsigned FlatAffineValueConstraints::appendSymbolVar(ValueRange vals) { 279 unsigned pos = getNumSymbolVars(); 280 insertVar(VarKind::Symbol, pos, vals); 281 return pos; 282 } 283 284 unsigned FlatAffineValueConstraints::insertDimVar(unsigned pos, 285 ValueRange vals) { 286 return insertVar(VarKind::SetDim, pos, vals); 287 } 288 289 unsigned FlatAffineValueConstraints::insertSymbolVar(unsigned pos, 290 ValueRange vals) { 291 return insertVar(VarKind::Symbol, pos, vals); 292 } 293 294 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos, 295 unsigned num) { 296 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); 297 298 if (kind != VarKind::Local) { 299 values.insert(values.begin() + absolutePos, num, None); 300 assert(values.size() == getNumDimAndSymbolVars()); 301 } 302 303 return 
absolutePos; 304 } 305 306 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos, 307 ValueRange vals) { 308 assert(!vals.empty() && "expected ValueRange with Values."); 309 assert(kind != VarKind::Local && 310 "values cannot be attached to local variables."); 311 unsigned num = vals.size(); 312 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); 313 314 // If a Value is provided, insert it; otherwise use None. 315 for (unsigned i = 0; i < num; ++i) 316 values.insert(values.begin() + absolutePos + i, 317 vals[i] ? Optional<Value>(vals[i]) : None); 318 319 assert(values.size() == getNumDimAndSymbolVars()); 320 return absolutePos; 321 } 322 323 bool FlatAffineValueConstraints::hasValues() const { 324 return llvm::find_if(values, [](Optional<Value> var) { 325 return var.hasValue(); 326 }) != values.end(); 327 } 328 329 /// Checks if two constraint systems are in the same space, i.e., if they are 330 /// associated with the same set of variables, appearing in the same order. 331 static bool areVarsAligned(const FlatAffineValueConstraints &a, 332 const FlatAffineValueConstraints &b) { 333 return a.getNumDimVars() == b.getNumDimVars() && 334 a.getNumSymbolVars() == b.getNumSymbolVars() && 335 a.getNumVars() == b.getNumVars() && 336 a.getMaybeValues().equals(b.getMaybeValues()); 337 } 338 339 /// Calls areVarsAligned to check if two constraint systems have the same set 340 /// of variables in the same order. 341 bool FlatAffineValueConstraints::areVarsAlignedWithOther( 342 const FlatAffineValueConstraints &other) { 343 return areVarsAligned(*this, other); 344 } 345 346 /// Checks if the SSA values associated with `cst`'s variables in range 347 /// [start, end) are unique. 
348 static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( 349 const FlatAffineValueConstraints &cst, unsigned start, unsigned end) { 350 351 assert(start <= cst.getNumDimAndSymbolVars() && 352 "Start position out of bounds"); 353 assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds"); 354 355 if (start >= end) 356 return true; 357 358 SmallPtrSet<Value, 8> uniqueVars; 359 ArrayRef<Optional<Value>> maybeValues = 360 cst.getMaybeValues().slice(start, end - start); 361 for (Optional<Value> val : maybeValues) { 362 if (val && !uniqueVars.insert(*val).second) 363 return false; 364 } 365 return true; 366 } 367 368 /// Checks if the SSA values associated with `cst`'s variables are unique. 369 static bool LLVM_ATTRIBUTE_UNUSED 370 areVarsUnique(const FlatAffineValueConstraints &cst) { 371 return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); 372 } 373 374 /// Checks if the SSA values associated with `cst`'s variables of kind `kind` 375 /// are unique. 376 static bool LLVM_ATTRIBUTE_UNUSED 377 areVarsUnique(const FlatAffineValueConstraints &cst, VarKind kind) { 378 379 if (kind == VarKind::SetDim) 380 return areVarsUnique(cst, 0, cst.getNumDimVars()); 381 if (kind == VarKind::Symbol) 382 return areVarsUnique(cst, cst.getNumDimVars(), 383 cst.getNumDimAndSymbolVars()); 384 llvm_unreachable("Unexpected VarKind"); 385 } 386 387 /// Merge and align the variables of A and B starting at 'offset', so that 388 /// both constraint systems get the union of the contained variables that is 389 /// dimension-wise and symbol-wise unique; both constraint systems are updated 390 /// so that they have the union of all variables, with A's original 391 /// variables appearing first followed by any of B's variables that didn't 392 /// appear in A. Local variables in B that have the same division 393 /// representation as local variables in A are merged into one. 
394 // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) 395 // Output: both A, B have (%i, %j, %k) [%M, %N, %P] 396 static void mergeAndAlignVars(unsigned offset, FlatAffineValueConstraints *a, 397 FlatAffineValueConstraints *b) { 398 assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars()); 399 // A merge/align isn't meaningful if a cst's vars aren't distinct. 400 assert(areVarsUnique(*a) && "A's values aren't unique"); 401 assert(areVarsUnique(*b) && "B's values aren't unique"); 402 403 assert(std::all_of(a->getMaybeValues().begin() + offset, 404 a->getMaybeValues().end(), 405 [](Optional<Value> var) { return var.hasValue(); })); 406 407 assert(std::all_of(b->getMaybeValues().begin() + offset, 408 b->getMaybeValues().end(), 409 [](Optional<Value> var) { return var.hasValue(); })); 410 411 SmallVector<Value, 4> aDimValues; 412 a->getValues(offset, a->getNumDimVars(), &aDimValues); 413 414 { 415 // Merge dims from A into B. 416 unsigned d = offset; 417 for (auto aDimValue : aDimValues) { 418 unsigned loc; 419 if (b->findVar(aDimValue, &loc)) { 420 assert(loc >= offset && "A's dim appears in B's aligned range"); 421 assert(loc < b->getNumDimVars() && 422 "A's dim appears in B's non-dim position"); 423 b->swapVar(d, loc); 424 } else { 425 b->insertDimVar(d, aDimValue); 426 } 427 d++; 428 } 429 // Dimensions that are in B, but not in A, are added at the end. 430 for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) { 431 a->appendDimVar(b->getValue(t)); 432 } 433 assert(a->getNumDimVars() == b->getNumDimVars() && 434 "expected same number of dims"); 435 } 436 437 // Merge and align symbols of A and B 438 a->mergeSymbolVars(*b); 439 // Merge and align locals of A and B 440 a->mergeLocalVars(*b); 441 442 assert(areVarsAligned(*a, *b) && "IDs expected to be aligned"); 443 } 444 445 // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'. 
void FlatAffineValueConstraints::mergeAndAlignVarsWithOther(
    unsigned offset, FlatAffineValueConstraints *other) {
  mergeAndAlignVars(offset, this, other);
}

// Composes `vMap` into this system: the map is first re-expressed over this
// system's variables via computeAlignedMap (using the map's operand Values),
// and the result is then composed positionally by composeMatchingMap.
LogicalResult
FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
  return composeMatchingMap(
      computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
}

// Similar to `composeMap` except that no Values need be associated with the
// constraint system nor are they looked at -- the dimensions and symbols of
// `other` are expected to correspond 1:1 to `this` system.
// New dims, one per result of `other`, are inserted at position 0, and each is
// tied to its result expression by an equality. Returns failure if `other`
// cannot be flattened.
LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) {
  assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
  assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");

  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
    return failure();
  assert(flatExprs.size() == other.getNumResults());

  // Add dimensions corresponding to the map's results.
  insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());

  // We add one equality for each result connecting the result dim of the map to
  // the other variables.
  // E.g.: if the expression is 16*i0 + i1, and this is the r^th
  // iteration/result of the value map, we are adding the equality:
  // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
  // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
    const auto &flatExpr = flatExprs[r];
    assert(flatExpr.size() >= other.getNumInputs() + 1);

    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
    // Set the coefficient for this result to one.
    eqToAdd[r] = 1;

    // Dims and symbols. The original dims/symbols now start right after the
    // `e` result dims inserted above.
    for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
      // Negate `eq[r]` since the newly added dimension will be set to this one.
      eqToAdd[e + i] = -flatExpr[i];
    }
    // The local columns of `flatExpr` follow its dims/symbols; they map onto
    // this system's local columns, which start after all dims and symbols.
    unsigned j = getNumDimVars() + getNumSymbolVars();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
      eqToAdd[j] = -flatExpr[i];
    }

    // Constant term.
    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];

    // Add the equality connecting the result of the map to this constraint set.
    addEquality(eqToAdd);
  }

  return success();
}

// Turn a symbol into a dimension. If `value` is attached to a symbol of `cst`,
// that symbol is swapped into the first symbol position and the dim/symbol
// boundary is re-split with one fewer symbol, making it the last dim. Values
// that are not symbols of `cst` are ignored.
static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value value) {
  unsigned pos;
  if (cst->findVar(value, &pos) && pos >= cst->getNumDimVars() &&
      pos < cst->getNumDimAndSymbolVars()) {
    cst->swapVar(pos, cst->getNumDimVars());
    cst->setDimSymbolSeparation(cst->getNumSymbolVars() - 1);
  }
}

/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be unequal to all
/// other symbols.
void FlatAffineValueConstraints::mergeSymbolVars(
    FlatAffineValueConstraints &other) {

  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");

  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  unsigned s = other.getNumDimVars();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the var is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol
    if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
        loc < other.getNumDimAndSymbolVars())
      other.swapVar(s, loc);
    else
      other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
                e = other.getNumDimAndSymbolVars();
       t < e; t++)
    insertSymbolVar(getNumSymbolVars(), other.getValue(t));

  assert(getNumSymbolVars() == other.getNumSymbolVars() &&
         "expected same number of symbols");
  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
}

// Changes all symbol variables which are loop IVs to dim variables.
void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
  // Gather all symbols which are loop IVs. They are collected first because
  // turnSymbolIntoDim reorders variables.
  SmallVector<Value, 4> loopIVs;
  for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) {
    if (hasValue(i) && getForInductionVarOwner(getValue(i)))
      loopIVs.push_back(getValue(i));
  }
  // Turn each symbol in 'loopIVs' into a dim variable.
  for (auto iv : loopIVs) {
    turnSymbolIntoDim(this, iv);
  }
}

// Adds `val` to the system if it is not already present: an affine.for IV is
// appended as a dim together with its loop's domain, while any other
// (top-level) value is appended as a symbol, pinned to its value if it is
// defined by an arith.constant index op.
void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
  if (containsVar(val))
    return;

  // Caller is expected to fully compose map/operands if necessary.
  assert((isTopLevelValue(val) || isForInductionVar(val)) &&
         "non-terminal symbol / loop IV expected");
  // Outer loop IVs could be used in forOp's bounds.
  if (auto loop = getForInductionVarOwner(val)) {
    appendDimVar(val);
    // Failure to add the domain is only reported in debug builds; the IV is
    // kept as an (under-constrained) dim.
    if (failed(this->addAffineForOpDomain(loop)))
      LLVM_DEBUG(
          loop.emitWarning("failed to add domain info to constraint system"));
    return;
  }
  // Add top level symbol.
  appendSymbolVar(val);
  // Check if the symbol is a constant.
  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
    addBound(BoundType::EQ, val, constOp.value());
}

// Adds the iteration domain of `forOp` (lower/upper bounds and, for constant
// lower bounds, the step) as constraints on its IV, which must already be a
// variable of this system. Returns failure if the IV is absent or a
// non-constant bound map cannot be added.
LogicalResult
FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
  unsigned pos;
  // Pre-condition for this method.
  if (!findVar(forOp.getInductionVar(), &pos)) {
    assert(false && "Value not found");
    return failure();
  }

  int64_t step = forOp.getStep();
  if (step != 1) {
    // With a non-constant lower bound the stride is not encoded, so the domain
    // is over-approximated.
    if (!forOp.hasConstantLowerBound())
      LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
    else {
      // Add constraints for the stride.
      // (iv - lb) % step = 0 can be written as:
      // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
      // Add local variable 'q' and add the above equality.
      // The first constraint is q = (iv - lb) floordiv step
      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
      int64_t lb = forOp.getConstantLowerBound();
      dividend[pos] = 1;
      dividend.back() -= lb;
      addLocalFloorDiv(dividend, step);
      // Second constraint: (iv - lb) - step * q = 0.
      SmallVector<int64_t, 8> eq(getNumCols(), 0);
      eq[pos] = 1;
      eq.back() -= lb;
      // For the local var just added above.
      eq[getNumCols() - 2] = -step;
      addEquality(eq);
    }
  }

  if (forOp.hasConstantLowerBound()) {
    addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
  } else {
    // Non-constant lower bound case.
    if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
                        forOp.getLowerBoundOperands())))
      return failure();
  }

  // The loop upper bound is exclusive, hence the `- 1` for the constant case.
  if (forOp.hasConstantUpperBound()) {
    addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
    return success();
  }
  // Non-constant upper bound case.
  return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
                  forOp.getUpperBoundOperands());
}

// Adds the domain described by slice bound maps: entry i of `lbMaps`/`ubMaps`
// bounds dim i, with all maps taking `operands` as inputs. Null maps are
// skipped. Returns failure for unsupported single-iteration-equality slices or
// when a bound cannot be added.
LogicalResult
FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
                                                   ArrayRef<AffineMap> ubMaps,
                                                   ArrayRef<Value> operands) {
  assert(lbMaps.size() == ubMaps.size());
  assert(lbMaps.size() <= getNumDimVars());

  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
    AffineMap lbMap = lbMaps[i];
    AffineMap ubMap = ubMaps[i];
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
    assert(!ubMap || ubMap.getNumInputs() == operands.size());

    // Check if this slice is just an equality along this dimension. If so,
    // retrieve the existing loop it equates to and add it to the system.
    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
        ubMap.getNumResults() == 1 &&
        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
      // Limited support: we expect the lb result to be just a loop dimension.
      // Not supported otherwise for now.
      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
      if (!result)
        return failure();

      AffineForOp loop =
          getForInductionVarOwner(operands[result.getPosition()]);
      if (!loop)
        return failure();

      if (failed(addAffineForOpDomain(loop)))
        return failure();
      continue;
    }

    // This slice refers to a loop that doesn't exist in the IR yet. Add its
    // bounds to the system assuming its dimension variable position is the
    // same as the position of the loop in the loop nest.
    if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
      return failure();
    if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
      return failure();
  }
  return success();
}

// Intersects this system with the integer set of `ifOp`, binding the set's
// dims and symbols to the op's operands.
void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
  // Create the base constraints from the integer set attached to ifOp.
  FlatAffineValueConstraints cst(ifOp.getIntegerSet());

  // Bind vars in the constraints to ifOp operands.
  SmallVector<Value, 4> operands = ifOp.getOperands();
  cst.setValues(0, cst.getNumDimAndSymbolVars(), operands);

  // Merge the constraints from ifOp to the current domain. We need first merge
  // and align the IDs from both constraints, and then append the constraints
  // from the ifOp into the current one.
  mergeAndAlignVarsWithOther(0, &cst);
  append(cst);
}

// The base polyhedron must be consistent and there must be exactly one Value
// slot per dim and symbol.
bool FlatAffineValueConstraints::hasConsistentState() const {
  return IntegerPolyhedron::hasConsistentState() &&
         values.size() == getNumDimAndSymbolVars();
}

// Removes variables [varStart, varLimit) of `kind`. For dims and symbols the
// corresponding Value slots are erased too (locals have no Value slots).
void FlatAffineValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
                                                unsigned varLimit) {
  IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
  unsigned offset = getVarKindOffset(kind);

  if (kind != VarKind::Local) {
    values.erase(values.begin() + varStart + offset,
                 values.begin() + varLimit + offset);
  }
}

// Determine whether the variable at 'pos' (say var_r) can be expressed as
// modulo of another known variable (say var_n) w.r.t a constant. For example,
// if the following constraints hold true:
// ```
// 0 <= var_r <= divisor - 1
// var_n - (divisor * q_expr) = var_r
// ```
// where `var_n` is a known variable (called dividend), and `q_expr` is an
// `AffineExpr` (called the quotient expression), `var_r` can be written as:
//
// `var_r = var_n mod divisor`.
//
// Additionally, in a special case of the above constraints where `q_expr` is a
// variable itself that is not yet known (say `var_q`), it can be written as a
// floordiv in the following way:
//
// `var_q = var_n floordiv divisor`.
//
// `lbConst` / `ubConst` are the constant lower / upper bounds of var_r, and
// `memo` holds the already-known expression (or null) for each variable.
//
// Returns true if the above mod or floordiv are detected, updating 'memo' with
// these new expressions. Returns false otherwise.
static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
                        int64_t lbConst, int64_t ubConst,
                        SmallVectorImpl<AffineExpr> &memo,
                        MLIRContext *context) {
  assert(pos < cst.getNumVars() && "invalid position");

  // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
  // be determined.
  if (lbConst != 0 || ubConst < 1)
    return false;
  int64_t divisor = ubConst + 1;

  // Check for the aforementioned conditions in each equality.
  for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
       curEquality < numEqualities; curEquality++) {
    int64_t coefficientAtPos = cst.atEq(curEquality, pos);
    // If current equality does not involve `var_r`, continue to the next
    // equality.
    if (coefficientAtPos == 0)
      continue;

    // Constant term should be 0 in this equality.
    if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
      continue;

    // Traverse through the equality and construct the dividend expression
    // `dividendExpr`, to contain all the variables which are known and are
    // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
    // `dividendExpr` gets simplified into a single variable `var_n` discussed
    // above.
    auto dividendExpr = getAffineConstantExpr(0, context);

    // Track the terms that go into quotient expression, later used to detect
    // additional floordiv.
    unsigned quotientCount = 0;
    int quotientPosition = -1;
    int quotientSign = 1;

    // Consider each term in the current equality.
    // NOTE(review): only dim and symbol columns are scanned here; coefficients
    // of local variables in this equality are not examined. Confirm that is
    // intended.
    unsigned curVar, e;
    for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
      // Ignore var_r.
      if (curVar == pos)
        continue;
      int64_t coefficientOfCurVar = cst.atEq(curEquality, curVar);
      // Ignore vars that do not contribute to the current equality.
      if (coefficientOfCurVar == 0)
        continue;
      // Check if the current var goes into the quotient expression.
      if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
        quotientCount++;
        quotientPosition = curVar;
        quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
        continue;
      }
      // Variables that are part of dividendExpr should be known.
      if (!memo[curVar])
        break;
      // Append the current variable to the dividend expression.
      dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
    }

    // Can't construct expression as it depends on a yet uncomputed var.
    if (curVar < e)
      continue;

    // Express `var_r` in terms of the other vars collected so far.
    if (coefficientAtPos > 0)
      dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
    else
      dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);

    // Simplify the expression.
    dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(),
                                      cst.getNumSymbolVars());
    // Only if the final dividend expression is just a single var (which we call
    // `var_n`), we can proceed.
    // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
    // to dims themselves.
    auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
    if (!dimExpr)
      continue;

    // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
    if (quotientCount >= 1) {
      auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
                                     dimExpr.getPosition());
      // If `var_n` has an upperbound that is less than the divisor, mod can be
      // eliminated altogether.
      if (ub && *ub < divisor)
        memo[pos] = dimExpr;
      else
        memo[pos] = dimExpr % divisor;
      // If a unique quotient `var_q` was seen, it can be expressed as
      // `var_n floordiv divisor`.
      if (quotientCount == 1 && !memo[quotientPosition])
        memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;

      return true;
    }
  }
  return false;
}

/// Check if the pos^th variable can be expressed as a floordiv of an affine
/// function of other variables (where the divisor is a positive constant)
/// given the initial set of expressions in `exprs`. If it can be, the
/// corresponding position in `exprs` is set as the detected affine expr. For
/// For example: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4, and
/// 32q + 28 <= i <= 32q + 31 => q = i floordiv 32.
/// Note: an equality such as 4q = i + j also determines q = (i + j) floordiv 4,
/// but representations that `computeSingleVarRepr` reports as
/// `ReprKind::Equality` are rejected below, so this function only succeeds on
/// representations derived from an upper/lower bound pair.
/// Returns true (and sets `exprs[pos]`) on success, false otherwise.
static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
                             unsigned pos, MLIRContext *context,
                             SmallVectorImpl<AffineExpr> &exprs) {
  assert(pos < cst.getNumVars() && "invalid position");

  // Mark every variable whose expression is already known so that
  // computeSingleVarRepr may use it in the dividend.
  SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
  for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
    if (exprs[i])
      foundRepr[i] = true;

  SmallVector<int64_t, 8> dividend(cst.getNumCols());
  unsigned divisor;
  auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);

  // No usable upper-lower bound pair found for this var (equality-based
  // representations are not used here).
  if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
    return false;

  // Construct the dividend expression: constant term plus each known
  // variable's expression scaled by its coefficient.
  auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
  for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
    if (dividend[c] != 0)
      dividendExpr = dividendExpr + dividend[c] * exprs[c];

  // Successfully detected the floordiv.
  exprs[pos] = dividendExpr.floorDiv(divisor);
  return true;
}

/// Computes lower and upper bound maps for the dimensional variable at
/// position `pos + offset`. The bounds are expressed over every column outside
/// [offset, offset + num) (remaining dims, symbols, locals via `localExprs`,
/// and the constant). The variable's own column is dropped too, presumably
/// because callers pass `pos < num` (NOTE(review): this is not asserted).
/// - Each lower-bound inequality contributes a ceildiv result to the lb map.
/// - Each upper-bound inequality contributes a floordiv result plus 1
///   (exclusive) to the ub map.
/// - Each equality contributes to both maps.
/// The returned maps have `symStartPos - num` dims and
/// `getNumDimAndSymbolVars() - symStartPos` symbols.
std::pair<AffineMap, AffineMap>
FlatAffineValueConstraints::getLowerAndUpperBound(
    unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
    ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
  assert(pos + offset < getNumDimVars() && "invalid dim start pos");
  assert(symStartPos >= (pos + offset) && "invalid sym start pos");
  assert(getNumLocalVars() == localExprs.size() &&
         "incorrect local exprs count");

  SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
                               offset, num);

  /// Replace 'b' with every entry of 'a' whose index lies outside
  /// [offset, offset + num); this keeps the symbol, local and constant
  /// columns as well as the remaining dims.
  auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
    b.clear();
    for (unsigned i = 0, e = a.size(); i < e; ++i) {
      if (i < offset || i >= offset + num)
        b.push_back(a[i]);
    }
  };

  SmallVector<int64_t, 8> lb, ub;
  SmallVector<AffineExpr, 4> lbExprs;
  unsigned dimCount = symStartPos - num;
  unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
  lbExprs.reserve(lbIndices.size() + eqIndices.size());
  // Lower bound expressions.
  for (auto idx : lbIndices) {
    auto ineq = getInequality(idx);
    // Extract the lower bound in terms of the other coefficients and the
    // constant. For example, if 'pos' refers to i and the constraint is
    // i - j + 1 >= 0, the lower bound is j - 1.
    addCoeffs(ineq, lb);
    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
    auto expr =
        getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
    // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
    int64_t divisor = std::abs(ineq[pos + offset]);
    expr = (expr + divisor - 1).floorDiv(divisor);
    lbExprs.push_back(expr);
  }

  SmallVector<AffineExpr, 4> ubExprs;
  ubExprs.reserve(ubIndices.size() + eqIndices.size());
  // Upper bound expressions.
  for (auto idx : ubIndices) {
    auto ineq = getInequality(idx);
    // Extract the upper bound (in terms of other coeff's + const).
    addCoeffs(ineq, ub);
    auto expr =
        getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(ineq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
  }

  // Equalities are both a lower and an upper bound.
  SmallVector<int64_t, 4> b;
  for (auto idx : eqIndices) {
    auto eq = getEquality(idx);
    addCoeffs(eq, b);
    // Normalize the sign so that `b` is the expression the variable equals
    // (scaled by |coefficient|).
    if (eq[pos + offset] > 0)
      std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());

    // Extract the upper bound (in terms of other coeff's + const).
    auto expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(eq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
    // Lower bound (rounded up).
    expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.ceilDiv(std::abs(eq[pos + offset]));
    lbExprs.push_back(expr);
  }

  auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
  auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);

  return {lbMap, ubMap};
}

/// Computes the lower and upper bounds of the first 'num' dimensional
/// variables (starting at 'offset') as affine maps of the remaining
/// variables (dimensional and symbolic variables). Local variables are
/// themselves explicitly computed as affine functions of other variables in
/// this process if needed.
/// `lbMaps` and `ubMaps` must already hold at least `num` entries. They are
/// indexed, not appended to.
/// Upper bounds are exclusive unless `getClosedUB` is true. A bound that can
/// be neither detected nor recovered from constant bounds keeps whatever map
/// the caller stored at that index.
/// This method is non-const: it first normalizes the constraints by their GCD.
void FlatAffineValueConstraints::getSliceBounds(
    unsigned offset, unsigned num, MLIRContext *context,
    SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps,
    bool getClosedUB) {
  assert(num < getNumDimVars() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " variables\n");
  LLVM_DEBUG(dump());

  // Record computed/detected variables.
  SmallVector<AffineExpr, 8> memo(getNumVars());
  // Initialize dimensional and symbolic variables. Dims outside
  // [offset, offset + num) become the map dims (renumbered to skip the sliced
  // range); the sliced dims start out unknown (null).
  for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown variables as constants or mod's / floordiv's of
    // other variables if possible.
    for (unsigned pos = 0; pos < getNumVars(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound(BoundType::LB, pos);
      auto ubConst = getConstantBound(BoundType::UB, pos);
      if (lbConst.hasValue() && ubConst.hasValue()) {
        // Detect equality to a constant.
        if (lbConst.getValue() == ubConst.getValue()) {
          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
          changed = true;
          continue;
        }

        // Detect a variable as modulo of another variable w.r.t a
        // constant.
        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
                        memo, context)) {
          changed = true;
          continue;
        }
      }

      // Detect a variable as a floordiv of an affine function of other
      // variables (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect a variable as an expression of other variables, using the
      // first equality in which it has a non-zero coefficient.
      unsigned idx;
      if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for variable 'pos' in terms of all others.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumVars(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // variable.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq(idx, getNumVars());
      int64_t vPos = atEq(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point - since once a
    // variable's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  int64_t ubAdjustment = getClosedUB ? 0 : 1;

  // Set the lower and upper bound maps for all the variables that were
  // computed as affine expressions of the rest as the "detected expr" and
  // "detected expr + ubAdjustment" respectively. For undetected variables,
  // try the bounds from the constraints themselves, then constant bounds.
  Optional<FlatAffineValueConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    unsigned numMapDims = getNumDimVars() - num;
    unsigned numMapSymbols = getNumSymbolVars();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
    } else {
      // TODO: Whenever there are local variables in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalVars() == 0) {
        // Work on a copy so that we don't update this constraint system.
        // The copy is created lazily and shared by all remaining positions.
        if (!tmpClone) {
          tmpClone.emplace(FlatAffineValueConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound(BoundType::LB, pos + offset);
        if (lbConst.hasValue()) {
          lbMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(lbConst.getValue(), context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound(BoundType::UB, pos + offset);
        if (ubConst.hasValue()) {
          ubMap =
              AffineMap::get(numMapDims, numMapSymbols,
                             getAffineConstantExpr(
                                 ubConst.getValue() + ubAdjustment, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}

/// Flattens `map` into `flattenedExprs`. The map's dims and symbols are
/// expected to line up with this system's.
/// Local variables introduced by flattening (for mods/divs) are inserted at the
/// start of this system's local variables. Their defining constraints are
/// appended.
/// Returns failure if the map cannot be flattened (e.g. a semi-affine map).
LogicalResult FlatAffineValueConstraints::flattenAlignedMapAndMergeLocals(
    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
  FlatAffineValueConstraints localCst;
  if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
    LLVM_DEBUG(llvm::dbgs()
               << "composition unimplemented for semi-affine maps\n");
    return failure();
  }

  // Add localCst information.
  if (localCst.getNumLocalVars() > 0) {
    unsigned numLocalVars = getNumLocalVars();
    // Insert local dims of localCst at the beginning.
    insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
    // Insert local dims of `this` at the end of localCst.
    localCst.appendLocalVar(/*num=*/numLocalVars);
    // Dimensions of localCst and this constraint set match. Append localCst to
    // this constraint set.
    append(localCst);
  }

  return success();
}

/// Adds a bound of kind `type` on the variable at `pos` for every result of
/// `boundMap`. EQ adds an equality; LB and UB add inequalities.
/// `boundMap` must have exactly this system's dims and symbols as inputs.
/// When `isClosedBound` is false, the bound is strict and is tightened by one.
/// EQ bounds must be closed.
/// Results whose flattened form references `pos` itself are silently skipped.
/// Returns failure if `boundMap` cannot be flattened.
LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
                                                   AffineMap boundMap,
                                                   bool isClosedBound) {
  assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
  assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
  assert(pos < getNumDimAndSymbolVars() && "invalid position");
  assert((type != BoundType::EQ || isClosedBound) &&
         "EQ bound must be closed.");

  // Equality follows the logic of lower bound except that we add an equality
  // instead of an inequality.
  assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
         "single result expected");
  bool lower = type == BoundType::LB || type == BoundType::EQ;

  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
    return failure();
  assert(flatExprs.size() == boundMap.getNumResults());

  // Add one (in)equality for each result.
  for (const auto &flatExpr : flatExprs) {
    SmallVector<int64_t> ineq(getNumCols(), 0);
    // Dims and symbols.
    for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
      ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
    }
    // Invalid bound: pos appears in `boundMap`.
    // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
    // its callers to prevent invalid bounds from being added.
    if (ineq[pos] != 0)
      continue;
    ineq[pos] = lower ? 1 : -1;
    // The locals introduced by flattening were inserted at the beginning of
    // this system's local columns (see flattenAlignedMapAndMergeLocals).
    unsigned j = getNumDimVars() + getNumSymbolVars();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
      ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
    }
    // Both lower and upper bounds are written in `... >= 0` form, so a strict
    // (open) bound is made closed by subtracting 1 from the constant term.
    int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
    // Constant term.
    ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
                                    : flatExpr[flatExpr.size() - 1]) +
                             boundAdjustment;
    type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
  }

  return success();
}

/// Convenience overload of addBound. UB maps are treated as exclusive (open);
/// LB and EQ maps are treated as closed.
LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
                                                   AffineMap boundMap) {
  return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
}

/// Returns `map` rewritten so that its dims and symbols refer to this system's
/// dim/symbol positions. Each operand is matched against the Values attached to
/// this system. In debug builds, asserts that no operand would introduce a new
/// symbol.
AffineMap
FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
                                              ValueRange operands) const {
  assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");

  SmallVector<Value> dims, syms;
#ifndef NDEBUG
  SmallVector<Value> newSyms;
  SmallVector<Value> *newSymsPtr = &newSyms;
#else
  SmallVector<Value> *newSymsPtr = nullptr;
#endif // NDEBUG

  dims.reserve(getNumDimVars());
  syms.reserve(getNumSymbolVars());
  for (unsigned i = getVarKindOffset(VarKind::SetDim),
                e = getVarKindEnd(VarKind::SetDim);
       i < e; ++i)
    dims.push_back(values[i] ? *values[i] : Value());
  for (unsigned i = getVarKindOffset(VarKind::Symbol),
                e = getVarKindEnd(VarKind::Symbol);
       i < e; ++i)
    syms.push_back(values[i] ? *values[i] : Value());

  AffineMap alignedMap =
      alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
  // All symbols are already part of this FlatAffineConstraints.
  // (newSymsPtr is only dereferenced inside asserts, i.e. in debug builds.)
  assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
  assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
         "unexpected new/missing symbols");
  return alignedMap;
}

/// Adds a bound given as `boundMap` applied to `boundOperands`. The map and
/// operands are fully composed, simplified and canonicalized. Each resulting
/// operand is registered via addInductionVarOrTerminalSymbol. The bound is then
/// added with the aligned map (UB exclusive, LB/EQ closed).
LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
                                                   AffineMap boundMap,
                                                   ValueRange boundOperands) {
  // Fully compose map and operands; canonicalize and simplify so that we
  // transitively get to terminal symbols or loop IVs.
  auto map = boundMap;
  SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  canonicalizeMapAndOperands(&map, &operands);
  for (auto operand : operands)
    addInductionVarOrTerminalSymbol(operand);
  return addBound(type, pos, computeAlignedMap(map, operands));
}

// Adds slice bounds to each value in `values` that appears in the constraint
// system; values that are not found are skipped. The lower bound comes from
// 'lbMaps' and the upper bound from 'ubMaps', and both use the operand list
// 'operands'. This function assumes
// 'values.size' == 'lbMaps.size' == 'ubMaps.size'.
// - If both maps have a single result and ub == lb + 1, an equality is added.
// - If either map is null or has no results, the loop is treated as not
//   sliced: the domain of the loop owning the value as induction variable is
//   added instead.
// Returns failure if adding any bound fails (e.g. when a bound map cannot be
// flattened, such as a semi-affine map).
1288 LogicalResult FlatAffineValueConstraints::addSliceBounds( 1289 ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps, 1290 ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) { 1291 assert(values.size() == lbMaps.size()); 1292 assert(lbMaps.size() == ubMaps.size()); 1293 1294 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { 1295 unsigned pos; 1296 if (!findVar(values[i], &pos)) 1297 continue; 1298 1299 AffineMap lbMap = lbMaps[i]; 1300 AffineMap ubMap = ubMaps[i]; 1301 assert(!lbMap || lbMap.getNumInputs() == operands.size()); 1302 assert(!ubMap || ubMap.getNumInputs() == operands.size()); 1303 1304 // Check if this slice is just an equality along this dimension. 1305 if (lbMap && ubMap && lbMap.getNumResults() == 1 && 1306 ubMap.getNumResults() == 1 && 1307 lbMap.getResult(0) + 1 == ubMap.getResult(0)) { 1308 if (failed(addBound(BoundType::EQ, pos, lbMap, operands))) 1309 return failure(); 1310 continue; 1311 } 1312 1313 // If lower or upper bound maps are null or provide no results, it implies 1314 // that the source loop was not at all sliced, and the entire loop will be a 1315 // part of the slice. 
1316 if (lbMap && lbMap.getNumResults() != 0 && ubMap && 1317 ubMap.getNumResults() != 0) { 1318 if (failed(addBound(BoundType::LB, pos, lbMap, operands))) 1319 return failure(); 1320 if (failed(addBound(BoundType::UB, pos, ubMap, operands))) 1321 return failure(); 1322 } else { 1323 auto loop = getForInductionVarOwner(values[i]); 1324 if (failed(this->addAffineForOpDomain(loop))) 1325 return failure(); 1326 } 1327 } 1328 return success(); 1329 } 1330 1331 bool FlatAffineValueConstraints::findVar(Value val, unsigned *pos) const { 1332 unsigned i = 0; 1333 for (const auto &mayBeVar : values) { 1334 if (mayBeVar && *mayBeVar == val) { 1335 *pos = i; 1336 return true; 1337 } 1338 i++; 1339 } 1340 return false; 1341 } 1342 1343 bool FlatAffineValueConstraints::containsVar(Value val) const { 1344 return llvm::any_of(values, [&](const Optional<Value> &mayBeVar) { 1345 return mayBeVar && *mayBeVar == val; 1346 }); 1347 } 1348 1349 void FlatAffineValueConstraints::swapVar(unsigned posA, unsigned posB) { 1350 IntegerPolyhedron::swapVar(posA, posB); 1351 1352 if (getVarKindAt(posA) == VarKind::Local && 1353 getVarKindAt(posB) == VarKind::Local) 1354 return; 1355 1356 // Treat value of a local variable as None. 1357 if (getVarKindAt(posA) == VarKind::Local) 1358 values[posB] = None; 1359 else if (getVarKindAt(posB) == VarKind::Local) 1360 values[posA] = None; 1361 else 1362 std::swap(values[posA], values[posB]); 1363 } 1364 1365 void FlatAffineValueConstraints::addBound(BoundType type, Value val, 1366 int64_t value) { 1367 unsigned pos; 1368 if (!findVar(val, &pos)) 1369 // This is a pre-condition for this method. 
1370 assert(0 && "var not found"); 1371 addBound(type, pos, value); 1372 } 1373 1374 void FlatAffineValueConstraints::printSpace(raw_ostream &os) const { 1375 IntegerPolyhedron::printSpace(os); 1376 os << "("; 1377 for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) { 1378 if (hasValue(i)) 1379 os << "Value "; 1380 else 1381 os << "None "; 1382 } 1383 for (unsigned i = getVarKindOffset(VarKind::Local), 1384 e = getVarKindEnd(VarKind::Local); 1385 i < e; ++i) 1386 os << "Local "; 1387 os << " const)\n"; 1388 } 1389 1390 void FlatAffineValueConstraints::clearAndCopyFrom( 1391 const IntegerRelation &other) { 1392 1393 if (auto *otherValueSet = 1394 dyn_cast<const FlatAffineValueConstraints>(&other)) { 1395 *this = *otherValueSet; 1396 } else { 1397 *static_cast<IntegerRelation *>(this) = other; 1398 values.clear(); 1399 values.resize(getNumDimAndSymbolVars(), None); 1400 } 1401 } 1402 1403 void FlatAffineValueConstraints::fourierMotzkinEliminate( 1404 unsigned pos, bool darkShadow, bool *isResultIntegerExact) { 1405 SmallVector<Optional<Value>, 8> newVals = values; 1406 if (getVarKindAt(pos) != VarKind::Local) 1407 newVals.erase(newVals.begin() + pos); 1408 // Note: Base implementation discards all associated Values. 
1409 IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow, 1410 isResultIntegerExact); 1411 values = newVals; 1412 assert(values.size() == getNumDimAndSymbolVars()); 1413 } 1414 1415 void FlatAffineValueConstraints::projectOut(Value val) { 1416 unsigned pos; 1417 bool ret = findVar(val, &pos); 1418 assert(ret); 1419 (void)ret; 1420 fourierMotzkinEliminate(pos); 1421 } 1422 1423 LogicalResult FlatAffineValueConstraints::unionBoundingBox( 1424 const FlatAffineValueConstraints &otherCst) { 1425 assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch"); 1426 assert(otherCst.getMaybeValues() 1427 .slice(0, getNumDimVars()) 1428 .equals(getMaybeValues().slice(0, getNumDimVars())) && 1429 "dim values mismatch"); 1430 assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here"); 1431 assert(getNumLocalVars() == 0 && "local vars not supported yet here"); 1432 1433 // Align `other` to this. 1434 if (!areVarsAligned(*this, otherCst)) { 1435 FlatAffineValueConstraints otherCopy(otherCst); 1436 mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy); 1437 return IntegerPolyhedron::unionBoundingBox(otherCopy); 1438 } 1439 1440 return IntegerPolyhedron::unionBoundingBox(otherCst); 1441 } 1442 1443 /// Compute an explicit representation for local vars. For all systems coming 1444 /// from MLIR integer sets, maps, or expressions where local vars were 1445 /// introduced to model floordivs and mods, this always succeeds. 1446 static LogicalResult computeLocalVars(const FlatAffineValueConstraints &cst, 1447 SmallVectorImpl<AffineExpr> &memo, 1448 MLIRContext *context) { 1449 unsigned numDims = cst.getNumDimVars(); 1450 unsigned numSyms = cst.getNumSymbolVars(); 1451 1452 // Initialize dimensional and symbolic variables. 
1453 for (unsigned i = 0; i < numDims; i++) 1454 memo[i] = getAffineDimExpr(i, context); 1455 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) 1456 memo[i] = getAffineSymbolExpr(i - numDims, context); 1457 1458 bool changed; 1459 do { 1460 // Each time `changed` is true at the end of this iteration, one or more 1461 // local vars would have been detected as floordivs and set in memo; so the 1462 // number of null entries in memo[...] strictly reduces; so this converges. 1463 changed = false; 1464 for (unsigned i = 0, e = cst.getNumLocalVars(); i < e; ++i) 1465 if (!memo[numDims + numSyms + i] && 1466 detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo)) 1467 changed = true; 1468 } while (changed); 1469 1470 ArrayRef<AffineExpr> localExprs = 1471 ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalVars()); 1472 return success( 1473 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; })); 1474 } 1475 1476 void FlatAffineValueConstraints::getIneqAsAffineValueMap( 1477 unsigned pos, unsigned ineqPos, AffineValueMap &vmap, 1478 MLIRContext *context) const { 1479 unsigned numDims = getNumDimVars(); 1480 unsigned numSyms = getNumSymbolVars(); 1481 1482 assert(pos < numDims && "invalid position"); 1483 assert(ineqPos < getNumInequalities() && "invalid inequality position"); 1484 1485 // Get expressions for local vars. 1486 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); 1487 if (failed(computeLocalVars(*this, memo, context))) 1488 assert(false && 1489 "one or more local exprs do not have an explicit representation"); 1490 auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 1491 1492 // Compute the AffineExpr lower/upper bound for this inequality. 1493 ArrayRef<int64_t> inequality = getInequality(ineqPos); 1494 SmallVector<int64_t, 8> bound; 1495 bound.reserve(getNumCols() - 1); 1496 // Everything other than the coefficient at `pos`. 
1497 bound.append(inequality.begin(), inequality.begin() + pos); 1498 bound.append(inequality.begin() + pos + 1, inequality.end()); 1499 1500 if (inequality[pos] > 0) 1501 // Lower bound. 1502 std::transform(bound.begin(), bound.end(), bound.begin(), 1503 std::negate<int64_t>()); 1504 else 1505 // Upper bound (which is exclusive). 1506 bound.back() += 1; 1507 1508 // Convert to AffineExpr (tree) form. 1509 auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms, 1510 localExprs, context); 1511 1512 // Get the values to bind to this affine expr (all dims and symbols). 1513 SmallVector<Value, 4> operands; 1514 getValues(0, pos, &operands); 1515 SmallVector<Value, 4> trailingOperands; 1516 getValues(pos + 1, getNumDimAndSymbolVars(), &trailingOperands); 1517 operands.append(trailingOperands.begin(), trailingOperands.end()); 1518 vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands); 1519 } 1520 1521 IntegerSet 1522 FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const { 1523 if (getNumConstraints() == 0) 1524 // Return universal set (always true): 0 == 0. 1525 return IntegerSet::get(getNumDimVars(), getNumSymbolVars(), 1526 getAffineConstantExpr(/*constant=*/0, context), 1527 /*eqFlags=*/true); 1528 1529 // Construct local references. 1530 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); 1531 1532 if (failed(computeLocalVars(*this, memo, context))) { 1533 // Check if the local variables without an explicit representation have 1534 // zero coefficients everywhere. 
1535 SmallVector<unsigned> noLocalRepVars; 1536 unsigned numDimsSymbols = getNumDimAndSymbolVars(); 1537 for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) { 1538 if (!memo[i] && !isColZero(/*pos=*/i)) 1539 noLocalRepVars.push_back(i - numDimsSymbols); 1540 } 1541 if (!noLocalRepVars.empty()) { 1542 LLVM_DEBUG({ 1543 llvm::dbgs() << "local variables at position(s) "; 1544 llvm::interleaveComma(noLocalRepVars, llvm::dbgs()); 1545 llvm::dbgs() << " do not have an explicit representation in:\n"; 1546 this->dump(); 1547 }); 1548 return IntegerSet(); 1549 } 1550 } 1551 1552 ArrayRef<AffineExpr> localExprs = 1553 ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 1554 1555 // Construct the IntegerSet from the equalities/inequalities. 1556 unsigned numDims = getNumDimVars(); 1557 unsigned numSyms = getNumSymbolVars(); 1558 1559 SmallVector<bool, 16> eqFlags(getNumConstraints()); 1560 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true); 1561 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false); 1562 1563 SmallVector<AffineExpr, 8> exprs; 1564 exprs.reserve(getNumConstraints()); 1565 1566 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 1567 exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms, 1568 localExprs, context)); 1569 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 1570 exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims, 1571 numSyms, localExprs, context)); 1572 return IntegerSet::get(numDims, numSyms, exprs, eqFlags); 1573 } 1574 1575 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, 1576 ValueRange dims, ValueRange syms, 1577 SmallVector<Value> *newSyms) { 1578 assert(operands.size() == map.getNumInputs() && 1579 "expected same number of operands and map inputs"); 1580 MLIRContext *ctx = map.getContext(); 1581 Builder builder(ctx); 1582 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {}); 1583 unsigned numSymbols = 
syms.size(); 1584 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {}); 1585 if (newSyms) { 1586 newSyms->clear(); 1587 newSyms->append(syms.begin(), syms.end()); 1588 } 1589 1590 for (const auto &operand : llvm::enumerate(operands)) { 1591 // Compute replacement dim/sym of operand. 1592 AffineExpr replacement; 1593 auto dimIt = std::find(dims.begin(), dims.end(), operand.value()); 1594 auto symIt = std::find(syms.begin(), syms.end(), operand.value()); 1595 if (dimIt != dims.end()) { 1596 replacement = 1597 builder.getAffineDimExpr(std::distance(dims.begin(), dimIt)); 1598 } else if (symIt != syms.end()) { 1599 replacement = 1600 builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt)); 1601 } else { 1602 // This operand is neither a dimension nor a symbol. Add it as a new 1603 // symbol. 1604 replacement = builder.getAffineSymbolExpr(numSymbols++); 1605 if (newSyms) 1606 newSyms->push_back(operand.value()); 1607 } 1608 // Add to corresponding replacements vector. 1609 if (operand.index() < map.getNumDims()) { 1610 dimReplacements[operand.index()] = replacement; 1611 } else { 1612 symReplacements[operand.index() - map.getNumDims()] = replacement; 1613 } 1614 } 1615 1616 return map.replaceDimsAndSymbols(dimReplacements, symReplacements, 1617 dims.size(), numSymbols); 1618 } 1619 1620 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { 1621 FlatAffineValueConstraints domain = *this; 1622 // Convert all range variables to local variables. 1623 domain.convertToLocal(VarKind::SetDim, getNumDomainDims(), 1624 getNumDomainDims() + getNumRangeDims()); 1625 return domain; 1626 } 1627 1628 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { 1629 FlatAffineValueConstraints range = *this; 1630 // Convert all domain variables to local variables. 
1631 range.convertToLocal(VarKind::SetDim, 0, getNumDomainDims()); 1632 return range; 1633 } 1634 1635 void FlatAffineRelation::compose(const FlatAffineRelation &other) { 1636 assert(getNumDomainDims() == other.getNumRangeDims() && 1637 "Domain of this and range of other do not match"); 1638 assert(std::equal(values.begin(), values.begin() + getNumDomainDims(), 1639 other.values.begin() + other.getNumDomainDims()) && 1640 "Domain of this and range of other do not match"); 1641 1642 FlatAffineRelation rel = other; 1643 1644 // Convert `rel` from 1645 // [otherDomain] -> [otherRange] 1646 // to 1647 // [otherDomain] -> [otherRange thisRange] 1648 // and `this` from 1649 // [thisDomain] -> [thisRange] 1650 // to 1651 // [otherDomain thisDomain] -> [thisRange]. 1652 unsigned removeDims = rel.getNumRangeDims(); 1653 insertDomainVar(0, rel.getNumDomainDims()); 1654 rel.appendRangeVar(getNumRangeDims()); 1655 1656 // Merge symbol and local variables. 1657 mergeSymbolVars(rel); 1658 mergeLocalVars(rel); 1659 1660 // Convert `rel` from [otherDomain] -> [otherRange thisRange] to 1661 // [otherDomain] -> [thisRange] by converting first otherRange range vars 1662 // to local vars. 1663 rel.convertToLocal(VarKind::SetDim, rel.getNumDomainDims(), 1664 rel.getNumDomainDims() + removeDims); 1665 // Convert `this` from [otherDomain thisDomain] -> [thisRange] to 1666 // [otherDomain] -> [thisRange] by converting last thisDomain domain vars 1667 // to local vars. 1668 convertToLocal(VarKind::SetDim, getNumDomainDims() - removeDims, 1669 getNumDomainDims()); 1670 1671 auto thisMaybeValues = getMaybeValues(VarKind::SetDim); 1672 auto relMaybeValues = rel.getMaybeValues(VarKind::SetDim); 1673 1674 // Add and match domain of `rel` to domain of `this`. 1675 for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) 1676 if (relMaybeValues[i].hasValue()) 1677 setValue(i, relMaybeValues[i].getValue()); 1678 // Add and match range of `this` to range of `rel`. 
1679 for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) { 1680 unsigned rangeIdx = rel.getNumDomainDims() + i; 1681 if (thisMaybeValues[rangeIdx].hasValue()) 1682 rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue()); 1683 } 1684 1685 // Append `this` to `rel` and simplify constraints. 1686 rel.append(*this); 1687 rel.removeRedundantLocalVars(); 1688 1689 *this = rel; 1690 } 1691 1692 void FlatAffineRelation::inverse() { 1693 unsigned oldDomain = getNumDomainDims(); 1694 unsigned oldRange = getNumRangeDims(); 1695 // Add new range vars. 1696 appendRangeVar(oldDomain); 1697 // Swap new vars with domain. 1698 for (unsigned i = 0; i < oldDomain; ++i) 1699 swapVar(i, oldDomain + oldRange + i); 1700 // Remove the swapped domain. 1701 removeVarRange(0, oldDomain); 1702 // Set domain and range as inverse. 1703 numDomainDims = oldRange; 1704 numRangeDims = oldDomain; 1705 } 1706 1707 void FlatAffineRelation::insertDomainVar(unsigned pos, unsigned num) { 1708 assert(pos <= getNumDomainDims() && 1709 "Var cannot be inserted at invalid position"); 1710 insertDimVar(pos, num); 1711 numDomainDims += num; 1712 } 1713 1714 void FlatAffineRelation::insertRangeVar(unsigned pos, unsigned num) { 1715 assert(pos <= getNumRangeDims() && 1716 "Var cannot be inserted at invalid position"); 1717 insertDimVar(getNumDomainDims() + pos, num); 1718 numRangeDims += num; 1719 } 1720 1721 void FlatAffineRelation::appendDomainVar(unsigned num) { 1722 insertDimVar(getNumDomainDims(), num); 1723 numDomainDims += num; 1724 } 1725 1726 void FlatAffineRelation::appendRangeVar(unsigned num) { 1727 insertDimVar(getNumDimVars(), num); 1728 numRangeDims += num; 1729 } 1730 1731 void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart, 1732 unsigned varLimit) { 1733 assert(varLimit <= getNumVarKind(kind)); 1734 if (varStart >= varLimit) 1735 return; 1736 1737 FlatAffineValueConstraints::removeVarRange(kind, varStart, varLimit); 1738 1739 // If kind is not SetDim, domain and range 
don't need to be updated. 1740 if (kind != VarKind::SetDim) 1741 return; 1742 1743 // Compute number of domain and range variables to remove. This is done by 1744 // intersecting the range of domain/range vars with range of vars to remove. 1745 unsigned intersectDomainLHS = std::min(varLimit, getNumDomainDims()); 1746 unsigned intersectDomainRHS = varStart; 1747 unsigned intersectRangeLHS = std::min(varLimit, getNumDimVars()); 1748 unsigned intersectRangeRHS = std::max(varStart, getNumDomainDims()); 1749 1750 if (intersectDomainLHS > intersectDomainRHS) 1751 numDomainDims -= intersectDomainLHS - intersectDomainRHS; 1752 if (intersectRangeLHS > intersectRangeRHS) 1753 numRangeDims -= intersectRangeLHS - intersectRangeRHS; 1754 } 1755 1756 LogicalResult mlir::getRelationFromMap(AffineMap &map, 1757 FlatAffineRelation &rel) { 1758 // Get flattened affine expressions. 1759 std::vector<SmallVector<int64_t, 8>> flatExprs; 1760 FlatAffineValueConstraints localVarCst; 1761 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) 1762 return failure(); 1763 1764 unsigned oldDimNum = localVarCst.getNumDimVars(); 1765 unsigned oldCols = localVarCst.getNumCols(); 1766 unsigned numRangeVars = map.getNumResults(); 1767 unsigned numDomainVars = map.getNumDims(); 1768 1769 // Add range as the new expressions. 1770 localVarCst.appendDimVar(numRangeVars); 1771 1772 // Add equalities between source and range. 1773 SmallVector<int64_t, 8> eq(localVarCst.getNumCols()); 1774 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { 1775 // Zero fill. 1776 std::fill(eq.begin(), eq.end(), 0); 1777 // Fill equality. 1778 for (unsigned j = 0, f = oldDimNum; j < f; ++j) 1779 eq[j] = flatExprs[i][j]; 1780 for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) 1781 eq[j + numRangeVars] = flatExprs[i][j]; 1782 // Set this dimension to -1 to equate lhs and rhs and add equality. 
1783 eq[numDomainVars + i] = -1; 1784 localVarCst.addEquality(eq); 1785 } 1786 1787 // Create relation and return success. 1788 rel = FlatAffineRelation(numDomainVars, numRangeVars, localVarCst); 1789 return success(); 1790 } 1791 1792 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map, 1793 FlatAffineRelation &rel) { 1794 1795 AffineMap affineMap = map.getAffineMap(); 1796 if (failed(getRelationFromMap(affineMap, rel))) 1797 return failure(); 1798 1799 // Set symbol values for domain dimensions and symbols. 1800 for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) 1801 rel.setValue(i, map.getOperand(i)); 1802 for (unsigned i = rel.getNumDimVars(), e = rel.getNumDimAndSymbolVars(); 1803 i < e; ++i) 1804 rel.setValue(i, map.getOperand(i - rel.getNumRangeDims())); 1805 1806 return success(); 1807 } 1808