1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Structures for affine/polyhedral analysis of affine dialect ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 14 #include "mlir/Analysis/Presburger/LinearTransform.h" 15 #include "mlir/Analysis/Presburger/Simplex.h" 16 #include "mlir/Analysis/Presburger/Utils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 20 #include "mlir/IR/AffineExprVisitor.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Support/MathExtras.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #define DEBUG_TYPE "affine-structures" 31 32 using namespace mlir; 33 using namespace presburger; 34 35 namespace { 36 37 // See comments for SimpleAffineExprFlattener. 38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording 39 // constraint information associated with mod's, floordiv's, and ceildiv's 40 // in FlatAffineConstraints 'localVarCst'. 41 struct AffineExprFlattener : public SimpleAffineExprFlattener { 42 public: 43 // Constraints connecting newly introduced local variables (for mod's and 44 // div's) to existing (dimensional and symbolic) ones. These are always 45 // inequalities. 46 FlatAffineConstraints localVarCst; 47 48 AffineExprFlattener(unsigned nDims, unsigned nSymbols) 49 : SimpleAffineExprFlattener(nDims, nSymbols) { 50 localVarCst.reset(nDims, nSymbols, /*numLocals=*/0); 51 } 52 53 private: 54 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 55 // The local identifier added is always a floordiv of a pure add/mul affine 56 // function of other identifiers, coefficients of which are specified in 57 // `dividend' and with respect to the positive constant `divisor'. localExpr 58 // is the simplified tree expression (AffineExpr) corresponding to the 59 // quantifier. 60 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 61 AffineExpr localExpr) override { 62 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 63 // Update localVarCst. 64 localVarCst.addLocalFloorDiv(dividend, divisor); 65 } 66 }; 67 68 } // namespace 69 70 // Flattens the expressions in map. Returns failure if 'expr' was unable to be 71 // flattened (i.e., semi-affine expressions not handled yet). 72 static LogicalResult 73 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, 74 unsigned numSymbols, 75 std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 76 FlatAffineConstraints *localVarCst) { 77 if (exprs.empty()) { 78 localVarCst->reset(numDims, numSymbols); 79 return success(); 80 } 81 82 AffineExprFlattener flattener(numDims, numSymbols); 83 // Use the same flattener to simplify each expression successively. This way 84 // local identifiers / expressions are shared. 85 for (auto expr : exprs) { 86 if (!expr.isPureAffine()) 87 return failure(); 88 89 flattener.walkPostOrder(expr); 90 } 91 92 assert(flattener.operandExprStack.size() == exprs.size()); 93 flattenedExprs->clear(); 94 flattenedExprs->assign(flattener.operandExprStack.begin(), 95 flattener.operandExprStack.end()); 96 97 if (localVarCst) 98 localVarCst->clearAndCopyFrom(flattener.localVarCst); 99 100 return success(); 101 } 102 103 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 104 // be flattened (semi-affine expressions not handled yet). 105 LogicalResult 106 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, 107 unsigned numSymbols, 108 SmallVectorImpl<int64_t> *flattenedExpr, 109 FlatAffineConstraints *localVarCst) { 110 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 111 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, 112 &flattenedExprs, localVarCst); 113 *flattenedExpr = flattenedExprs[0]; 114 return ret; 115 } 116 117 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 118 /// flattened (i.e., semi-affine expressions not handled yet). 119 LogicalResult mlir::getFlattenedAffineExprs( 120 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 121 FlatAffineConstraints *localVarCst) { 122 if (map.getNumResults() == 0) { 123 localVarCst->reset(map.getNumDims(), map.getNumSymbols()); 124 return success(); 125 } 126 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), 127 map.getNumSymbols(), flattenedExprs, 128 localVarCst); 129 } 130 131 LogicalResult mlir::getFlattenedAffineExprs( 132 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 133 FlatAffineConstraints *localVarCst) { 134 if (set.getNumConstraints() == 0) { 135 localVarCst->reset(set.getNumDims(), set.getNumSymbols()); 136 return success(); 137 } 138 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 139 set.getNumSymbols(), flattenedExprs, 140 localVarCst); 141 } 142 143 //===----------------------------------------------------------------------===// 144 // FlatAffineConstraints / FlatAffineValueConstraints. 145 //===----------------------------------------------------------------------===// 146 147 // Clones this object. 148 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const { 149 return std::make_unique<FlatAffineConstraints>(*this); 150 } 151 152 std::unique_ptr<FlatAffineValueConstraints> 153 FlatAffineValueConstraints::clone() const { 154 return std::make_unique<FlatAffineValueConstraints>(*this); 155 } 156 157 // Construct from an IntegerSet. 158 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) 159 : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(), 160 set.getNumDims() + set.getNumSymbols() + 1, 161 set.getNumDims(), set.getNumSymbols(), 162 /*numLocals=*/0) { 163 164 // Flatten expressions and add them to the constraint system. 165 std::vector<SmallVector<int64_t, 8>> flatExprs; 166 FlatAffineConstraints localVarCst; 167 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 168 assert(false && "flattening unimplemented for semi-affine integer sets"); 169 return; 170 } 171 assert(flatExprs.size() == set.getNumConstraints()); 172 appendLocalId(/*num=*/localVarCst.getNumLocalIds()); 173 174 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 175 const auto &flatExpr = flatExprs[i]; 176 assert(flatExpr.size() == getNumCols()); 177 if (set.getEqFlags()[i]) { 178 addEquality(flatExpr); 179 } else { 180 addInequality(flatExpr); 181 } 182 } 183 // Add the other constraints involving local id's from flattening. 184 append(localVarCst); 185 } 186 187 // Construct from an IntegerSet. 188 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set) 189 : FlatAffineConstraints(set) { 190 values.resize(getNumIds(), None); 191 } 192 193 // Construct a hyperrectangular constraint set from ValueRanges that represent 194 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are 195 // expected to match one to one. The order of variables and constraints is: 196 // 197 // ivs | lbs | ubs | eq/ineq 198 // ----+-----+-----+--------- 199 // 1 -1 0 >= 0 200 // ----+-----+-----+--------- 201 // -1 0 1 >= 0 202 // 203 // All dimensions as set as DimId. 204 FlatAffineValueConstraints 205 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs, 206 ValueRange ubs) { 207 FlatAffineValueConstraints res; 208 unsigned nIvs = ivs.size(); 209 assert(nIvs == lbs.size() && "expected as many lower bounds as ivs"); 210 assert(nIvs == ubs.size() && "expected as many upper bounds as ivs"); 211 212 if (nIvs == 0) 213 return res; 214 215 res.appendDimId(ivs); 216 unsigned lbsStart = res.appendDimId(lbs); 217 unsigned ubsStart = res.appendDimId(ubs); 218 219 MLIRContext *ctx = ivs.front().getContext(); 220 for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) { 221 // iv - lb >= 0 222 AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 223 getAffineDimExpr(lbsStart + ivIdx, ctx)); 224 if (failed(res.addBound(BoundType::LB, ivIdx, lb))) 225 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 226 // -iv + ub >= 0 227 AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 228 getAffineDimExpr(ubsStart + ivIdx, ctx)); 229 if (failed(res.addBound(BoundType::UB, ivIdx, ub))) 230 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 231 } 232 return res; 233 } 234 235 void FlatAffineConstraints::reset(unsigned numReservedInequalities, 236 unsigned numReservedEqualities, 237 unsigned newNumReservedCols, 238 unsigned newNumDims, unsigned newNumSymbols, 239 unsigned newNumLocals) { 240 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 241 "minimum 1 column"); 242 *this = FlatAffineConstraints(numReservedInequalities, numReservedEqualities, 243 newNumReservedCols, newNumDims, newNumSymbols, 244 newNumLocals); 245 } 246 247 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, 248 unsigned newNumLocals) { 249 reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, 250 /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1, 251 newNumDims, newNumSymbols, newNumLocals); 252 } 253 254 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities, 255 unsigned numReservedEqualities, 256 unsigned newNumReservedCols, 257 unsigned newNumDims, 258 unsigned newNumSymbols, 259 unsigned newNumLocals) { 260 reset(numReservedInequalities, numReservedEqualities, newNumReservedCols, 261 newNumDims, newNumSymbols, newNumLocals, /*valArgs=*/{}); 262 } 263 264 void FlatAffineValueConstraints::reset( 265 unsigned numReservedInequalities, unsigned numReservedEqualities, 266 unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, 267 unsigned newNumLocals, ArrayRef<Value> valArgs) { 268 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 269 "minimum 1 column"); 270 SmallVector<Optional<Value>, 8> newVals; 271 if (!valArgs.empty()) 272 newVals.assign(valArgs.begin(), valArgs.end()); 273 274 *this = FlatAffineValueConstraints( 275 numReservedInequalities, numReservedEqualities, newNumReservedCols, 276 newNumDims, newNumSymbols, newNumLocals, newVals); 277 } 278 279 void FlatAffineValueConstraints::reset(unsigned newNumDims, 280 unsigned newNumSymbols, 281 unsigned newNumLocals, 282 ArrayRef<Value> valArgs) { 283 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, 284 newNumSymbols, newNumLocals, valArgs); 285 } 286 287 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) { 288 unsigned pos = getNumDimIds(); 289 insertId(IdKind::SetDim, pos, vals); 290 return pos; 291 } 292 293 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) { 294 unsigned pos = getNumSymbolIds(); 295 insertId(IdKind::Symbol, pos, vals); 296 return pos; 297 } 298 299 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos, 300 ValueRange vals) { 301 return insertId(IdKind::SetDim, pos, vals); 302 } 303 304 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos, 305 ValueRange vals) { 306 return insertId(IdKind::Symbol, pos, vals); 307 } 308 309 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos, 310 unsigned num) { 311 unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num); 312 values.insert(values.begin() + absolutePos, num, None); 313 assert(values.size() == getNumIds()); 314 return absolutePos; 315 } 316 317 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos, 318 ValueRange vals) { 319 assert(!vals.empty() && "expected ValueRange with Values"); 320 unsigned num = vals.size(); 321 unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num); 322 323 // If a Value is provided, insert it; otherwise use None. 324 for (unsigned i = 0; i < num; ++i) 325 values.insert(values.begin() + absolutePos + i, 326 vals[i] ? Optional<Value>(vals[i]) : None); 327 328 assert(values.size() == getNumIds()); 329 return absolutePos; 330 } 331 332 bool FlatAffineValueConstraints::hasValues() const { 333 return llvm::find_if(values, [](Optional<Value> id) { 334 return id.hasValue(); 335 }) != values.end(); 336 } 337 338 /// Checks if two constraint systems are in the same space, i.e., if they are 339 /// associated with the same set of identifiers, appearing in the same order. 340 static bool areIdsAligned(const FlatAffineValueConstraints &a, 341 const FlatAffineValueConstraints &b) { 342 return a.getNumDimIds() == b.getNumDimIds() && 343 a.getNumSymbolIds() == b.getNumSymbolIds() && 344 a.getNumIds() == b.getNumIds() && 345 a.getMaybeValues().equals(b.getMaybeValues()); 346 } 347 348 /// Calls areIdsAligned to check if two constraint systems have the same set 349 /// of identifiers in the same order. 350 bool FlatAffineValueConstraints::areIdsAlignedWithOther( 351 const FlatAffineValueConstraints &other) { 352 return areIdsAligned(*this, other); 353 } 354 355 /// Checks if the SSA values associated with `cst`'s identifiers in range 356 /// [start, end) are unique. 357 static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique( 358 const FlatAffineValueConstraints &cst, unsigned start, unsigned end) { 359 360 assert(start <= cst.getNumIds() && "Start position out of bounds"); 361 assert(end <= cst.getNumIds() && "End position out of bounds"); 362 363 if (start >= end) 364 return true; 365 366 SmallPtrSet<Value, 8> uniqueIds; 367 ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues(); 368 for (Optional<Value> val : maybeValues) { 369 if (val.hasValue() && !uniqueIds.insert(val.getValue()).second) 370 return false; 371 } 372 return true; 373 } 374 375 /// Checks if the SSA values associated with `cst`'s identifiers are unique. 376 static bool LLVM_ATTRIBUTE_UNUSED 377 areIdsUnique(const FlatAffineConstraints &cst) { 378 return areIdsUnique(cst, 0, cst.getNumIds()); 379 } 380 381 /// Checks if the SSA values associated with `cst`'s identifiers of kind `kind` 382 /// are unique. 383 static bool LLVM_ATTRIBUTE_UNUSED 384 areIdsUnique(const FlatAffineValueConstraints &cst, IdKind kind) { 385 386 if (kind == IdKind::SetDim) 387 return areIdsUnique(cst, 0, cst.getNumDimIds()); 388 if (kind == IdKind::Symbol) 389 return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds()); 390 if (kind == IdKind::Local) 391 return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds()); 392 llvm_unreachable("Unexpected IdKind"); 393 } 394 395 /// Merge and align the identifiers of A and B starting at 'offset', so that 396 /// both constraint systems get the union of the contained identifiers that is 397 /// dimension-wise and symbol-wise unique; both constraint systems are updated 398 /// so that they have the union of all identifiers, with A's original 399 /// identifiers appearing first followed by any of B's identifiers that didn't 400 /// appear in A. Local identifiers in B that have the same division 401 /// representation as local identifiers in A are merged into one. 402 // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) 403 // Output: both A, B have (%i, %j, %k) [%M, %N, %P] 404 static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a, 405 FlatAffineValueConstraints *b) { 406 assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds()); 407 // A merge/align isn't meaningful if a cst's ids aren't distinct. 408 assert(areIdsUnique(*a) && "A's values aren't unique"); 409 assert(areIdsUnique(*b) && "B's values aren't unique"); 410 411 assert(std::all_of(a->getMaybeValues().begin() + offset, 412 a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(), 413 [](Optional<Value> id) { return id.hasValue(); })); 414 415 assert(std::all_of(b->getMaybeValues().begin() + offset, 416 b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(), 417 [](Optional<Value> id) { return id.hasValue(); })); 418 419 SmallVector<Value, 4> aDimValues; 420 a->getValues(offset, a->getNumDimIds(), &aDimValues); 421 422 { 423 // Merge dims from A into B. 424 unsigned d = offset; 425 for (auto aDimValue : aDimValues) { 426 unsigned loc; 427 if (b->findId(aDimValue, &loc)) { 428 assert(loc >= offset && "A's dim appears in B's aligned range"); 429 assert(loc < b->getNumDimIds() && 430 "A's dim appears in B's non-dim position"); 431 b->swapId(d, loc); 432 } else { 433 b->insertDimId(d, aDimValue); 434 } 435 d++; 436 } 437 // Dimensions that are in B, but not in A, are added at the end. 438 for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) { 439 a->appendDimId(b->getValue(t)); 440 } 441 assert(a->getNumDimIds() == b->getNumDimIds() && 442 "expected same number of dims"); 443 } 444 445 // Merge and align symbols of A and B 446 a->mergeSymbolIds(*b); 447 // Merge and align local ids of A and B 448 a->mergeLocalIds(*b); 449 450 assert(areIdsAligned(*a, *b) && "IDs expected to be aligned"); 451 } 452 453 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'. 454 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther( 455 unsigned offset, FlatAffineValueConstraints *other) { 456 mergeAndAlignIds(offset, this, other); 457 } 458 459 LogicalResult 460 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) { 461 return composeMatchingMap( 462 computeAlignedMap(vMap->getAffineMap(), vMap->getOperands())); 463 } 464 465 // Similar to `composeMap` except that no Values need be associated with the 466 // constraint system nor are they looked at -- the dimensions and symbols of 467 // `other` are expected to correspond 1:1 to `this` system. 468 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) { 469 assert(other.getNumDims() == getNumDimIds() && "dim mismatch"); 470 assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); 471 472 std::vector<SmallVector<int64_t, 8>> flatExprs; 473 if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs))) 474 return failure(); 475 assert(flatExprs.size() == other.getNumResults()); 476 477 // Add dimensions corresponding to the map's results. 478 insertDimId(/*pos=*/0, /*num=*/other.getNumResults()); 479 480 // We add one equality for each result connecting the result dim of the map to 481 // the other identifiers. 482 // E.g.: if the expression is 16*i0 + i1, and this is the r^th 483 // iteration/result of the value map, we are adding the equality: 484 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we 485 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. 486 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { 487 const auto &flatExpr = flatExprs[r]; 488 assert(flatExpr.size() >= other.getNumInputs() + 1); 489 490 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); 491 // Set the coefficient for this result to one. 492 eqToAdd[r] = 1; 493 494 // Dims and symbols. 495 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { 496 // Negate `eq[r]` since the newly added dimension will be set to this one. 497 eqToAdd[e + i] = -flatExpr[i]; 498 } 499 // Local columns of `eq` are at the beginning. 500 unsigned j = getNumDimIds() + getNumSymbolIds(); 501 unsigned end = flatExpr.size() - 1; 502 for (unsigned i = other.getNumInputs(); i < end; i++, j++) { 503 eqToAdd[j] = -flatExpr[i]; 504 } 505 506 // Constant term. 507 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; 508 509 // Add the equality connecting the result of the map to this constraint set. 510 addEquality(eqToAdd); 511 } 512 513 return success(); 514 } 515 516 // Turn a symbol into a dimension. 517 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) { 518 unsigned pos; 519 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && 520 pos < cst->getNumDimAndSymbolIds()) { 521 cst->swapId(pos, cst->getNumDimIds()); 522 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1); 523 } 524 } 525 526 /// Merge and align symbols of `this` and `other` such that both get union of 527 /// of symbols that are unique. Symbols in `this` and `other` should be 528 /// unique. Symbols with Value as `None` are considered to be inequal to all 529 /// other symbols. 530 void FlatAffineValueConstraints::mergeSymbolIds( 531 FlatAffineValueConstraints &other) { 532 533 assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); 534 assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); 535 536 SmallVector<Value, 4> aSymValues; 537 getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues); 538 539 // Merge symbols: merge symbols into `other` first from `this`. 540 unsigned s = other.getNumDimIds(); 541 for (Value aSymValue : aSymValues) { 542 unsigned loc; 543 // If the id is a symbol in `other`, then align it, otherwise assume that 544 // it is a new symbol 545 if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() && 546 loc < other.getNumDimAndSymbolIds()) 547 other.swapId(s, loc); 548 else 549 other.insertSymbolId(s - other.getNumDimIds(), aSymValue); 550 s++; 551 } 552 553 // Symbols that are in other, but not in this, are added at the end. 554 for (unsigned t = other.getNumDimIds() + getNumSymbolIds(), 555 e = other.getNumDimAndSymbolIds(); 556 t < e; t++) 557 insertSymbolId(getNumSymbolIds(), other.getValue(t)); 558 559 assert(getNumSymbolIds() == other.getNumSymbolIds() && 560 "expected same number of symbols"); 561 assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); 562 assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); 563 } 564 565 // Changes all symbol identifiers which are loop IVs to dim identifiers. 566 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() { 567 // Gather all symbols which are loop IVs. 568 SmallVector<Value, 4> loopIVs; 569 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { 570 if (hasValue(i) && getForInductionVarOwner(getValue(i))) 571 loopIVs.push_back(getValue(i)); 572 } 573 // Turn each symbol in 'loopIVs' into a dim identifier. 574 for (auto iv : loopIVs) { 575 turnSymbolIntoDim(this, iv); 576 } 577 } 578 579 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { 580 if (containsId(val)) 581 return; 582 583 // Caller is expected to fully compose map/operands if necessary. 584 assert((isTopLevelValue(val) || isForInductionVar(val)) && 585 "non-terminal symbol / loop IV expected"); 586 // Outer loop IVs could be used in forOp's bounds. 587 if (auto loop = getForInductionVarOwner(val)) { 588 appendDimId(val); 589 if (failed(this->addAffineForOpDomain(loop))) 590 LLVM_DEBUG( 591 loop.emitWarning("failed to add domain info to constraint system")); 592 return; 593 } 594 // Add top level symbol. 595 appendSymbolId(val); 596 // Check if the symbol is a constant. 597 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) 598 addBound(BoundType::EQ, val, constOp.value()); 599 } 600 601 LogicalResult 602 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { 603 unsigned pos; 604 // Pre-condition for this method. 605 if (!findId(forOp.getInductionVar(), &pos)) { 606 assert(false && "Value not found"); 607 return failure(); 608 } 609 610 int64_t step = forOp.getStep(); 611 if (step != 1) { 612 if (!forOp.hasConstantLowerBound()) 613 LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated")); 614 else { 615 // Add constraints for the stride. 616 // (iv - lb) % step = 0 can be written as: 617 // (iv - lb) - step * q = 0 where q = (iv - lb) / step. 618 // Add local variable 'q' and add the above equality. 619 // The first constraint is q = (iv - lb) floordiv step 620 SmallVector<int64_t, 8> dividend(getNumCols(), 0); 621 int64_t lb = forOp.getConstantLowerBound(); 622 dividend[pos] = 1; 623 dividend.back() -= lb; 624 addLocalFloorDiv(dividend, step); 625 // Second constraint: (iv - lb) - step * q = 0. 626 SmallVector<int64_t, 8> eq(getNumCols(), 0); 627 eq[pos] = 1; 628 eq.back() -= lb; 629 // For the local var just added above. 630 eq[getNumCols() - 2] = -step; 631 addEquality(eq); 632 } 633 } 634 635 if (forOp.hasConstantLowerBound()) { 636 addBound(BoundType::LB, pos, forOp.getConstantLowerBound()); 637 } else { 638 // Non-constant lower bound case. 639 if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(), 640 forOp.getLowerBoundOperands()))) 641 return failure(); 642 } 643 644 if (forOp.hasConstantUpperBound()) { 645 addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1); 646 return success(); 647 } 648 // Non-constant upper bound case. 649 return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(), 650 forOp.getUpperBoundOperands()); 651 } 652 653 LogicalResult 654 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, 655 ArrayRef<AffineMap> ubMaps, 656 ArrayRef<Value> operands) { 657 assert(lbMaps.size() == ubMaps.size()); 658 assert(lbMaps.size() <= getNumDimIds()); 659 660 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { 661 AffineMap lbMap = lbMaps[i]; 662 AffineMap ubMap = ubMaps[i]; 663 assert(!lbMap || lbMap.getNumInputs() == operands.size()); 664 assert(!ubMap || ubMap.getNumInputs() == operands.size()); 665 666 // Check if this slice is just an equality along this dimension. If so, 667 // retrieve the existing loop it equates to and add it to the system. 668 if (lbMap && ubMap && lbMap.getNumResults() == 1 && 669 ubMap.getNumResults() == 1 && 670 lbMap.getResult(0) + 1 == ubMap.getResult(0) && 671 // The condition above will be true for maps describing a single 672 // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). 673 // Make sure we skip those cases by checking that the lb result is not 674 // just a constant. 675 !lbMap.getResult(0).isa<AffineConstantExpr>()) { 676 // Limited support: we expect the lb result to be just a loop dimension. 677 // Not supported otherwise for now. 678 AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>(); 679 if (!result) 680 return failure(); 681 682 AffineForOp loop = 683 getForInductionVarOwner(operands[result.getPosition()]); 684 if (!loop) 685 return failure(); 686 687 if (failed(addAffineForOpDomain(loop))) 688 return failure(); 689 continue; 690 } 691 692 // This slice refers to a loop that doesn't exist in the IR yet. Add its 693 // bounds to the system assuming its dimension identifier position is the 694 // same as the position of the loop in the loop nest. 695 if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands))) 696 return failure(); 697 if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands))) 698 return failure(); 699 } 700 return success(); 701 } 702 703 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { 704 // Create the base constraints from the integer set attached to ifOp. 705 FlatAffineValueConstraints cst(ifOp.getIntegerSet()); 706 707 // Bind ids in the constraints to ifOp operands. 708 SmallVector<Value, 4> operands = ifOp.getOperands(); 709 cst.setValues(0, cst.getNumDimAndSymbolIds(), operands); 710 711 // Merge the constraints from ifOp to the current domain. We need first merge 712 // and align the IDs from both constraints, and then append the constraints 713 // from the ifOp into the current one. 714 mergeAndAlignIdsWithOther(0, &cst); 715 append(cst); 716 } 717 718 bool FlatAffineValueConstraints::hasConsistentState() const { 719 return FlatAffineConstraints::hasConsistentState() && 720 values.size() == getNumIds(); 721 } 722 723 void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart, 724 unsigned idLimit) { 725 FlatAffineConstraints::removeIdRange(kind, idStart, idLimit); 726 unsigned offset = getIdKindOffset(kind); 727 values.erase(values.begin() + idStart + offset, 728 values.begin() + idLimit + offset); 729 } 730 731 // Determine whether the identifier at 'pos' (say id_r) can be expressed as 732 // modulo of another known identifier (say id_n) w.r.t a constant. For example, 733 // if the following constraints hold true: 734 // ``` 735 // 0 <= id_r <= divisor - 1 736 // id_n - (divisor * q_expr) = id_r 737 // ``` 738 // where `id_n` is a known identifier (called dividend), and `q_expr` is an 739 // `AffineExpr` (called the quotient expression), `id_r` can be written as: 740 // 741 // `id_r = id_n mod divisor`. 742 // 743 // Additionally, in a special case of the above constaints where `q_expr` is an 744 // identifier itself that is not yet known (say `id_q`), it can be written as a 745 // floordiv in the following way: 746 // 747 // `id_q = id_n floordiv divisor`. 748 // 749 // Returns true if the above mod or floordiv are detected, updating 'memo' with 750 // these new expressions. Returns false otherwise. 751 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, 752 int64_t lbConst, int64_t ubConst, 753 SmallVectorImpl<AffineExpr> &memo, 754 MLIRContext *context) { 755 assert(pos < cst.getNumIds() && "invalid position"); 756 757 // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can 758 // be determined. 759 if (lbConst != 0 || ubConst < 1) 760 return false; 761 int64_t divisor = ubConst + 1; 762 763 // Check for the aforementioned conditions in each equality. 764 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); 765 curEquality < numEqualities; curEquality++) { 766 int64_t coefficientAtPos = cst.atEq(curEquality, pos); 767 // If current equality does not involve `id_r`, continue to the next 768 // equality. 769 if (coefficientAtPos == 0) 770 continue; 771 772 // Constant term should be 0 in this equality. 773 if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0) 774 continue; 775 776 // Traverse through the equality and construct the dividend expression 777 // `dividendExpr`, to contain all the identifiers which are known and are 778 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the 779 // `dividendExpr` gets simplified into a single identifier `id_n` discussed 780 // above. 781 auto dividendExpr = getAffineConstantExpr(0, context); 782 783 // Track the terms that go into quotient expression, later used to detect 784 // additional floordiv. 785 unsigned quotientCount = 0; 786 int quotientPosition = -1; 787 int quotientSign = 1; 788 789 // Consider each term in the current equality. 790 unsigned curId, e; 791 for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) { 792 // Ignore id_r. 793 if (curId == pos) 794 continue; 795 int64_t coefficientOfCurId = cst.atEq(curEquality, curId); 796 // Ignore ids that do not contribute to the current equality. 797 if (coefficientOfCurId == 0) 798 continue; 799 // Check if the current id goes into the quotient expression. 800 if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) { 801 quotientCount++; 802 quotientPosition = curId; 803 quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1; 804 continue; 805 } 806 // Identifiers that are part of dividendExpr should be known. 807 if (!memo[curId]) 808 break; 809 // Append the current identifier to the dividend expression. 810 dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId; 811 } 812 813 // Can't construct expression as it depends on a yet uncomputed id. 814 if (curId < e) 815 continue; 816 817 // Express `id_r` in terms of the other ids collected so far. 818 if (coefficientAtPos > 0) 819 dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos); 820 else 821 dividendExpr = dividendExpr.floorDiv(-coefficientAtPos); 822 823 // Simplify the expression. 824 dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(), 825 cst.getNumSymbolIds()); 826 // Only if the final dividend expression is just a single id (which we call 827 // `id_n`), we can proceed. 828 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it 829 // to dims themselves. 830 auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>(); 831 if (!dimExpr) 832 continue; 833 834 // Express `id_r` as `id_n % divisor` and store the expression in `memo`. 835 if (quotientCount >= 1) { 836 auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB, 837 dimExpr.getPosition()); 838 // If `id_n` has an upperbound that is less than the divisor, mod can be 839 // eliminated altogether. 840 if (ub.hasValue() && ub.getValue() < divisor) 841 memo[pos] = dimExpr; 842 else 843 memo[pos] = dimExpr % divisor; 844 // If a unique quotient `id_q` was seen, it can be expressed as 845 // `id_n floordiv divisor`. 846 if (quotientCount == 1 && !memo[quotientPosition]) 847 memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign; 848 849 return true; 850 } 851 } 852 return false; 853 } 854 855 /// Check if the pos^th identifier can be expressed as a floordiv of an affine 856 /// function of other identifiers (where the divisor is a positive constant) 857 /// given the initial set of expressions in `exprs`. If it can be, the 858 /// corresponding position in `exprs` is set as the detected affine expr. For 859 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can 860 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 861 /// <= i <= 32q + 31 => q = i floordiv 32. 862 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, 863 MLIRContext *context, 864 SmallVectorImpl<AffineExpr> &exprs) { 865 assert(pos < cst.getNumIds() && "invalid position"); 866 867 // Get upper-lower bound pair for this variable. 868 SmallVector<bool, 8> foundRepr(cst.getNumIds(), false); 869 for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i) 870 if (exprs[i]) 871 foundRepr[i] = true; 872 873 SmallVector<int64_t, 8> dividend; 874 unsigned divisor; 875 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); 876 877 // No upper-lower bound pair found for this var. 878 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) 879 return false; 880 881 // Construct the dividend expression. 882 auto dividendExpr = getAffineConstantExpr(dividend.back(), context); 883 for (unsigned c = 0, f = cst.getNumIds(); c < f; c++) 884 if (dividend[c] != 0) 885 dividendExpr = dividendExpr + dividend[c] * exprs[c]; 886 887 // Successfully detected the floordiv. 888 exprs[pos] = dividendExpr.floorDiv(divisor); 889 return true; 890 } 891 892 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound( 893 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, 894 ArrayRef<AffineExpr> localExprs, MLIRContext *context) const { 895 assert(pos + offset < getNumDimIds() && "invalid dim start pos"); 896 assert(symStartPos >= (pos + offset) && "invalid sym start pos"); 897 assert(getNumLocalIds() == localExprs.size() && 898 "incorrect local exprs count"); 899 900 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; 901 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices, 902 offset, num); 903 904 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). 905 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { 906 b.clear(); 907 for (unsigned i = 0, e = a.size(); i < e; ++i) { 908 if (i < offset || i >= offset + num) 909 b.push_back(a[i]); 910 } 911 }; 912 913 SmallVector<int64_t, 8> lb, ub; 914 SmallVector<AffineExpr, 4> lbExprs; 915 unsigned dimCount = symStartPos - num; 916 unsigned symCount = getNumDimAndSymbolIds() - symStartPos; 917 lbExprs.reserve(lbIndices.size() + eqIndices.size()); 918 // Lower bound expressions. 919 for (auto idx : lbIndices) { 920 auto ineq = getInequality(idx); 921 // Extract the lower bound (in terms of other coeff's + const), i.e., if 922 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j 923 // - 1. 924 addCoeffs(ineq, lb); 925 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); 926 auto expr = 927 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context); 928 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor 929 int64_t divisor = std::abs(ineq[pos + offset]); 930 expr = (expr + divisor - 1).floorDiv(divisor); 931 lbExprs.push_back(expr); 932 } 933 934 SmallVector<AffineExpr, 4> ubExprs; 935 ubExprs.reserve(ubIndices.size() + eqIndices.size()); 936 // Upper bound expressions. 937 for (auto idx : ubIndices) { 938 auto ineq = getInequality(idx); 939 // Extract the upper bound (in terms of other coeff's + const). 940 addCoeffs(ineq, ub); 941 auto expr = 942 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context); 943 expr = expr.floorDiv(std::abs(ineq[pos + offset])); 944 // Upper bound is exclusive. 945 ubExprs.push_back(expr + 1); 946 } 947 948 // Equalities. It's both a lower and a upper bound. 949 SmallVector<int64_t, 4> b; 950 for (auto idx : eqIndices) { 951 auto eq = getEquality(idx); 952 addCoeffs(eq, b); 953 if (eq[pos + offset] > 0) 954 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>()); 955 956 // Extract the upper bound (in terms of other coeff's + const). 957 auto expr = 958 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 959 expr = expr.floorDiv(std::abs(eq[pos + offset])); 960 // Upper bound is exclusive. 961 ubExprs.push_back(expr + 1); 962 // Lower bound. 963 expr = 964 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 965 expr = expr.ceilDiv(std::abs(eq[pos + offset])); 966 lbExprs.push_back(expr); 967 } 968 969 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context); 970 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context); 971 972 return {lbMap, ubMap}; 973 } 974 975 /// Computes the lower and upper bounds of the first 'num' dimensional 976 /// identifiers (starting at 'offset') as affine maps of the remaining 977 /// identifiers (dimensional and symbolic identifiers). Local identifiers are 978 /// themselves explicitly computed as affine functions of other identifiers in 979 /// this process if needed. 980 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, 981 MLIRContext *context, 982 SmallVectorImpl<AffineMap> *lbMaps, 983 SmallVectorImpl<AffineMap> *ubMaps) { 984 assert(num < getNumDimIds() && "invalid range"); 985 986 // Basic simplification. 987 normalizeConstraintsByGCD(); 988 989 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num 990 << " identifiers\n"); 991 LLVM_DEBUG(dump()); 992 993 // Record computed/detected identifiers. 994 SmallVector<AffineExpr, 8> memo(getNumIds()); 995 // Initialize dimensional and symbolic identifiers. 996 for (unsigned i = 0, e = getNumDimIds(); i < e; i++) { 997 if (i < offset) 998 memo[i] = getAffineDimExpr(i, context); 999 else if (i >= offset + num) 1000 memo[i] = getAffineDimExpr(i - num, context); 1001 } 1002 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) 1003 memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); 1004 1005 bool changed; 1006 do { 1007 changed = false; 1008 // Identify yet unknown identifiers as constants or mod's / floordiv's of 1009 // other identifiers if possible. 1010 for (unsigned pos = 0; pos < getNumIds(); pos++) { 1011 if (memo[pos]) 1012 continue; 1013 1014 auto lbConst = getConstantBound(BoundType::LB, pos); 1015 auto ubConst = getConstantBound(BoundType::UB, pos); 1016 if (lbConst.hasValue() && ubConst.hasValue()) { 1017 // Detect equality to a constant. 1018 if (lbConst.getValue() == ubConst.getValue()) { 1019 memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); 1020 changed = true; 1021 continue; 1022 } 1023 1024 // Detect an identifier as modulo of another identifier w.r.t a 1025 // constant. 1026 if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), 1027 memo, context)) { 1028 changed = true; 1029 continue; 1030 } 1031 } 1032 1033 // Detect an identifier as a floordiv of an affine function of other 1034 // identifiers (divisor is a positive constant). 1035 if (detectAsFloorDiv(*this, pos, context, memo)) { 1036 changed = true; 1037 continue; 1038 } 1039 1040 // Detect an identifier as an expression of other identifiers. 1041 unsigned idx; 1042 if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) { 1043 continue; 1044 } 1045 1046 // Build AffineExpr solving for identifier 'pos' in terms of all others. 1047 auto expr = getAffineConstantExpr(0, context); 1048 unsigned j, e; 1049 for (j = 0, e = getNumIds(); j < e; ++j) { 1050 if (j == pos) 1051 continue; 1052 int64_t c = atEq(idx, j); 1053 if (c == 0) 1054 continue; 1055 // If any of the involved IDs hasn't been found yet, we can't proceed. 1056 if (!memo[j]) 1057 break; 1058 expr = expr + memo[j] * c; 1059 } 1060 if (j < e) 1061 // Can't construct expression as it depends on a yet uncomputed 1062 // identifier. 1063 continue; 1064 1065 // Add constant term to AffineExpr. 1066 expr = expr + atEq(idx, getNumIds()); 1067 int64_t vPos = atEq(idx, pos); 1068 assert(vPos != 0 && "expected non-zero here"); 1069 if (vPos > 0) 1070 expr = (-expr).floorDiv(vPos); 1071 else 1072 // vPos < 0. 1073 expr = expr.floorDiv(-vPos); 1074 // Successfully constructed expression. 1075 memo[pos] = expr; 1076 changed = true; 1077 } 1078 // This loop is guaranteed to reach a fixed point - since once an 1079 // identifier's explicit form is computed (in memo[pos]), it's not updated 1080 // again. 1081 } while (changed); 1082 1083 // Set the lower and upper bound maps for all the identifiers that were 1084 // computed as affine expressions of the rest as the "detected expr" and 1085 // "detected expr + 1" respectively; set the undetected ones to null. 1086 Optional<FlatAffineConstraints> tmpClone; 1087 for (unsigned pos = 0; pos < num; pos++) { 1088 unsigned numMapDims = getNumDimIds() - num; 1089 unsigned numMapSymbols = getNumSymbolIds(); 1090 AffineExpr expr = memo[pos + offset]; 1091 if (expr) 1092 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); 1093 1094 AffineMap &lbMap = (*lbMaps)[pos]; 1095 AffineMap &ubMap = (*ubMaps)[pos]; 1096 1097 if (expr) { 1098 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); 1099 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1); 1100 } else { 1101 // TODO: Whenever there are local identifiers in the dependence 1102 // constraints, we'll conservatively over-approximate, since we don't 1103 // always explicitly compute them above (in the while loop). 1104 if (getNumLocalIds() == 0) { 1105 // Work on a copy so that we don't update this constraint system. 1106 if (!tmpClone) { 1107 tmpClone.emplace(FlatAffineConstraints(*this)); 1108 // Removing redundant inequalities is necessary so that we don't get 1109 // redundant loop bounds. 1110 tmpClone->removeRedundantInequalities(); 1111 } 1112 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( 1113 pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context); 1114 } 1115 1116 // If the above fails, we'll just use the constant lower bound and the 1117 // constant upper bound (if they exist) as the slice bounds. 1118 // TODO: being conservative for the moment in cases that 1119 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is 1120 // fixed (b/126426796). 1121 if (!lbMap || lbMap.getNumResults() > 1) { 1122 LLVM_DEBUG(llvm::dbgs() 1123 << "WARNING: Potentially over-approximating slice lb\n"); 1124 auto lbConst = getConstantBound(BoundType::LB, pos + offset); 1125 if (lbConst.hasValue()) { 1126 lbMap = AffineMap::get( 1127 numMapDims, numMapSymbols, 1128 getAffineConstantExpr(lbConst.getValue(), context)); 1129 } 1130 } 1131 if (!ubMap || ubMap.getNumResults() > 1) { 1132 LLVM_DEBUG(llvm::dbgs() 1133 << "WARNING: Potentially over-approximating slice ub\n"); 1134 auto ubConst = getConstantBound(BoundType::UB, pos + offset); 1135 if (ubConst.hasValue()) { 1136 (ubMap) = AffineMap::get( 1137 numMapDims, numMapSymbols, 1138 getAffineConstantExpr(ubConst.getValue() + 1, context)); 1139 } 1140 } 1141 } 1142 LLVM_DEBUG(llvm::dbgs() 1143 << "lb map for pos = " << Twine(pos + offset) << ", expr: "); 1144 LLVM_DEBUG(lbMap.dump();); 1145 LLVM_DEBUG(llvm::dbgs() 1146 << "ub map for pos = " << Twine(pos + offset) << ", expr: "); 1147 LLVM_DEBUG(ubMap.dump();); 1148 } 1149 } 1150 1151 LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals( 1152 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { 1153 FlatAffineConstraints localCst; 1154 if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) { 1155 LLVM_DEBUG(llvm::dbgs() 1156 << "composition unimplemented for semi-affine maps\n"); 1157 return failure(); 1158 } 1159 1160 // Add localCst information. 1161 if (localCst.getNumLocalIds() > 0) { 1162 unsigned numLocalIds = getNumLocalIds(); 1163 // Insert local dims of localCst at the beginning. 1164 insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds()); 1165 // Insert local dims of `this` at the end of localCst. 1166 localCst.appendLocalId(/*num=*/numLocalIds); 1167 // Dimensions of localCst and this constraint set match. Append localCst to 1168 // this constraint set. 1169 append(localCst); 1170 } 1171 1172 return success(); 1173 } 1174 1175 LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos, 1176 AffineMap boundMap) { 1177 assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch"); 1178 assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); 1179 assert(pos < getNumDimAndSymbolIds() && "invalid position"); 1180 1181 // Equality follows the logic of lower bound except that we add an equality 1182 // instead of an inequality. 1183 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && 1184 "single result expected"); 1185 bool lower = type == BoundType::LB || type == BoundType::EQ; 1186 1187 std::vector<SmallVector<int64_t, 8>> flatExprs; 1188 if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs))) 1189 return failure(); 1190 assert(flatExprs.size() == boundMap.getNumResults()); 1191 1192 // Add one (in)equality for each result. 1193 for (const auto &flatExpr : flatExprs) { 1194 SmallVector<int64_t> ineq(getNumCols(), 0); 1195 // Dims and symbols. 1196 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { 1197 ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; 1198 } 1199 // Invalid bound: pos appears in `boundMap`. 1200 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or 1201 // its callers to prevent invalid bounds from being added. 1202 if (ineq[pos] != 0) 1203 continue; 1204 ineq[pos] = lower ? 1 : -1; 1205 // Local columns of `ineq` are at the beginning. 1206 unsigned j = getNumDimIds() + getNumSymbolIds(); 1207 unsigned end = flatExpr.size() - 1; 1208 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { 1209 ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; 1210 } 1211 // Constant term. 1212 ineq[getNumCols() - 1] = 1213 lower ? -flatExpr[flatExpr.size() - 1] 1214 // Upper bound in flattenedExpr is an exclusive one. 1215 : flatExpr[flatExpr.size() - 1] - 1; 1216 type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq); 1217 } 1218 1219 return success(); 1220 } 1221 1222 AffineMap 1223 FlatAffineValueConstraints::computeAlignedMap(AffineMap map, 1224 ValueRange operands) const { 1225 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch"); 1226 1227 SmallVector<Value> dims, syms; 1228 #ifndef NDEBUG 1229 SmallVector<Value> newSyms; 1230 SmallVector<Value> *newSymsPtr = &newSyms; 1231 #else 1232 SmallVector<Value> *newSymsPtr = nullptr; 1233 #endif // NDEBUG 1234 1235 dims.reserve(getNumDimIds()); 1236 syms.reserve(getNumSymbolIds()); 1237 for (unsigned i = getIdKindOffset(IdKind::SetDim), 1238 e = getIdKindEnd(IdKind::SetDim); 1239 i < e; ++i) 1240 dims.push_back(values[i] ? *values[i] : Value()); 1241 for (unsigned i = getIdKindOffset(IdKind::Symbol), 1242 e = getIdKindEnd(IdKind::Symbol); 1243 i < e; ++i) 1244 syms.push_back(values[i] ? *values[i] : Value()); 1245 1246 AffineMap alignedMap = 1247 alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr); 1248 // All symbols are already part of this FlatAffineConstraints. 1249 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols"); 1250 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && 1251 "unexpected new/missing symbols"); 1252 return alignedMap; 1253 } 1254 1255 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, 1256 AffineMap boundMap, 1257 ValueRange boundOperands) { 1258 // Fully compose map and operands; canonicalize and simplify so that we 1259 // transitively get to terminal symbols or loop IVs. 1260 auto map = boundMap; 1261 SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end()); 1262 fullyComposeAffineMapAndOperands(&map, &operands); 1263 map = simplifyAffineMap(map); 1264 canonicalizeMapAndOperands(&map, &operands); 1265 for (auto operand : operands) 1266 addInductionVarOrTerminalSymbol(operand); 1267 return addBound(type, pos, computeAlignedMap(map, operands)); 1268 } 1269 1270 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper 1271 // bounds in 'ubMaps' to each value in `values' that appears in the constraint 1272 // system. Note that both lower/upper bounds share the same operand list 1273 // 'operands'. 1274 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and 1275 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'. 1276 // Note that both lower/upper bounds use operands from 'operands'. 1277 // Returns failure for unimplemented cases such as semi-affine expressions or 1278 // expressions with mod/floordiv. 1279 LogicalResult FlatAffineValueConstraints::addSliceBounds( 1280 ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps, 1281 ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) { 1282 assert(values.size() == lbMaps.size()); 1283 assert(lbMaps.size() == ubMaps.size()); 1284 1285 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { 1286 unsigned pos; 1287 if (!findId(values[i], &pos)) 1288 continue; 1289 1290 AffineMap lbMap = lbMaps[i]; 1291 AffineMap ubMap = ubMaps[i]; 1292 assert(!lbMap || lbMap.getNumInputs() == operands.size()); 1293 assert(!ubMap || ubMap.getNumInputs() == operands.size()); 1294 1295 // Check if this slice is just an equality along this dimension. 1296 if (lbMap && ubMap && lbMap.getNumResults() == 1 && 1297 ubMap.getNumResults() == 1 && 1298 lbMap.getResult(0) + 1 == ubMap.getResult(0)) { 1299 if (failed(addBound(BoundType::EQ, pos, lbMap, operands))) 1300 return failure(); 1301 continue; 1302 } 1303 1304 // If lower or upper bound maps are null or provide no results, it implies 1305 // that the source loop was not at all sliced, and the entire loop will be a 1306 // part of the slice. 1307 if (lbMap && lbMap.getNumResults() != 0 && ubMap && 1308 ubMap.getNumResults() != 0) { 1309 if (failed(addBound(BoundType::LB, pos, lbMap, operands))) 1310 return failure(); 1311 if (failed(addBound(BoundType::UB, pos, ubMap, operands))) 1312 return failure(); 1313 } else { 1314 auto loop = getForInductionVarOwner(values[i]); 1315 if (failed(this->addAffineForOpDomain(loop))) 1316 return failure(); 1317 } 1318 } 1319 return success(); 1320 } 1321 1322 bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const { 1323 unsigned i = 0; 1324 for (const auto &mayBeId : values) { 1325 if (mayBeId.hasValue() && mayBeId.getValue() == val) { 1326 *pos = i; 1327 return true; 1328 } 1329 i++; 1330 } 1331 return false; 1332 } 1333 1334 bool FlatAffineValueConstraints::containsId(Value val) const { 1335 return llvm::any_of(values, [&](const Optional<Value> &mayBeId) { 1336 return mayBeId.hasValue() && mayBeId.getValue() == val; 1337 }); 1338 } 1339 1340 void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) { 1341 FlatAffineConstraints::swapId(posA, posB); 1342 std::swap(values[posA], values[posB]); 1343 } 1344 1345 void FlatAffineValueConstraints::addBound(BoundType type, Value val, 1346 int64_t value) { 1347 unsigned pos; 1348 if (!findId(val, &pos)) 1349 // This is a pre-condition for this method. 1350 assert(0 && "id not found"); 1351 addBound(type, pos, value); 1352 } 1353 1354 void FlatAffineConstraints::printSpace(raw_ostream &os) const { 1355 IntegerPolyhedron::printSpace(os); 1356 os << "("; 1357 for (unsigned i = 0, e = getNumIds(); i < e; i++) { 1358 if (auto *valueCstr = dyn_cast<const FlatAffineValueConstraints>(this)) { 1359 if (valueCstr->hasValue(i)) 1360 os << "Value "; 1361 else 1362 os << "None "; 1363 } else { 1364 os << "None "; 1365 } 1366 } 1367 os << " const)\n"; 1368 } 1369 1370 void FlatAffineConstraints::clearAndCopyFrom(const IntegerRelation &other) { 1371 if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other)) 1372 assert(!otherValueSet->hasValues() && 1373 "cannot copy associated Values into FlatAffineConstraints"); 1374 1375 // Note: Assigment operator does not vtable pointer, so kind does not 1376 // change. 1377 if (auto *otherValueSet = dyn_cast<const FlatAffineConstraints>(&other)) 1378 *this = *otherValueSet; 1379 else 1380 *static_cast<IntegerRelation *>(this) = other; 1381 } 1382 1383 void FlatAffineValueConstraints::clearAndCopyFrom( 1384 const IntegerRelation &other) { 1385 1386 if (auto *otherValueSet = 1387 dyn_cast<const FlatAffineValueConstraints>(&other)) { 1388 *this = *otherValueSet; 1389 return; 1390 } 1391 1392 if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other)) 1393 *static_cast<FlatAffineConstraints *>(this) = *otherValueSet; 1394 else 1395 *static_cast<IntegerRelation *>(this) = other; 1396 1397 values.clear(); 1398 values.resize(getNumIds(), None); 1399 } 1400 1401 void FlatAffineValueConstraints::fourierMotzkinEliminate( 1402 unsigned pos, bool darkShadow, bool *isResultIntegerExact) { 1403 SmallVector<Optional<Value>, 8> newVals; 1404 newVals.reserve(getNumIds() - 1); 1405 newVals.append(values.begin(), values.begin() + pos); 1406 newVals.append(values.begin() + pos + 1, values.end()); 1407 // Note: Base implementation discards all associated Values. 1408 FlatAffineConstraints::fourierMotzkinEliminate(pos, darkShadow, 1409 isResultIntegerExact); 1410 values = newVals; 1411 assert(values.size() == getNumIds()); 1412 } 1413 1414 void FlatAffineValueConstraints::projectOut(Value val) { 1415 unsigned pos; 1416 bool ret = findId(val, &pos); 1417 assert(ret); 1418 (void)ret; 1419 fourierMotzkinEliminate(pos); 1420 } 1421 1422 LogicalResult FlatAffineValueConstraints::unionBoundingBox( 1423 const FlatAffineValueConstraints &otherCst) { 1424 assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); 1425 assert(otherCst.getMaybeValues() 1426 .slice(0, getNumDimIds()) 1427 .equals(getMaybeValues().slice(0, getNumDimIds())) && 1428 "dim values mismatch"); 1429 assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); 1430 assert(getNumLocalIds() == 0 && "local ids not supported yet here"); 1431 1432 // Align `other` to this. 1433 if (!areIdsAligned(*this, otherCst)) { 1434 FlatAffineValueConstraints otherCopy(otherCst); 1435 mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy); 1436 return FlatAffineConstraints::unionBoundingBox(otherCopy); 1437 } 1438 1439 return FlatAffineConstraints::unionBoundingBox(otherCst); 1440 } 1441 1442 /// Compute an explicit representation for local vars. For all systems coming 1443 /// from MLIR integer sets, maps, or expressions where local vars were 1444 /// introduced to model floordivs and mods, this always succeeds. 1445 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst, 1446 SmallVectorImpl<AffineExpr> &memo, 1447 MLIRContext *context) { 1448 unsigned numDims = cst.getNumDimIds(); 1449 unsigned numSyms = cst.getNumSymbolIds(); 1450 1451 // Initialize dimensional and symbolic identifiers. 1452 for (unsigned i = 0; i < numDims; i++) 1453 memo[i] = getAffineDimExpr(i, context); 1454 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) 1455 memo[i] = getAffineSymbolExpr(i - numDims, context); 1456 1457 bool changed; 1458 do { 1459 // Each time `changed` is true at the end of this iteration, one or more 1460 // local vars would have been detected as floordivs and set in memo; so the 1461 // number of null entries in memo[...] strictly reduces; so this converges. 1462 changed = false; 1463 for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i) 1464 if (!memo[numDims + numSyms + i] && 1465 detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo)) 1466 changed = true; 1467 } while (changed); 1468 1469 ArrayRef<AffineExpr> localExprs = 1470 ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds()); 1471 return success( 1472 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; })); 1473 } 1474 1475 void FlatAffineValueConstraints::getIneqAsAffineValueMap( 1476 unsigned pos, unsigned ineqPos, AffineValueMap &vmap, 1477 MLIRContext *context) const { 1478 unsigned numDims = getNumDimIds(); 1479 unsigned numSyms = getNumSymbolIds(); 1480 1481 assert(pos < numDims && "invalid position"); 1482 assert(ineqPos < getNumInequalities() && "invalid inequality position"); 1483 1484 // Get expressions for local vars. 1485 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr()); 1486 if (failed(computeLocalVars(*this, memo, context))) 1487 assert(false && 1488 "one or more local exprs do not have an explicit representation"); 1489 auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds()); 1490 1491 // Compute the AffineExpr lower/upper bound for this inequality. 1492 ArrayRef<int64_t> inequality = getInequality(ineqPos); 1493 SmallVector<int64_t, 8> bound; 1494 bound.reserve(getNumCols() - 1); 1495 // Everything other than the coefficient at `pos`. 1496 bound.append(inequality.begin(), inequality.begin() + pos); 1497 bound.append(inequality.begin() + pos + 1, inequality.end()); 1498 1499 if (inequality[pos] > 0) 1500 // Lower bound. 1501 std::transform(bound.begin(), bound.end(), bound.begin(), 1502 std::negate<int64_t>()); 1503 else 1504 // Upper bound (which is exclusive). 1505 bound.back() += 1; 1506 1507 // Convert to AffineExpr (tree) form. 1508 auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms, 1509 localExprs, context); 1510 1511 // Get the values to bind to this affine expr (all dims and symbols). 1512 SmallVector<Value, 4> operands; 1513 getValues(0, pos, &operands); 1514 SmallVector<Value, 4> trailingOperands; 1515 getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands); 1516 operands.append(trailingOperands.begin(), trailingOperands.end()); 1517 vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands); 1518 } 1519 1520 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const { 1521 if (getNumConstraints() == 0) 1522 // Return universal set (always true): 0 == 0. 1523 return IntegerSet::get(getNumDimIds(), getNumSymbolIds(), 1524 getAffineConstantExpr(/*constant=*/0, context), 1525 /*eqFlags=*/true); 1526 1527 // Construct local references. 1528 SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr()); 1529 1530 if (failed(computeLocalVars(*this, memo, context))) { 1531 // Check if the local variables without an explicit representation have 1532 // zero coefficients everywhere. 1533 SmallVector<unsigned> noLocalRepVars; 1534 unsigned numDimsSymbols = getNumDimAndSymbolIds(); 1535 for (unsigned i = numDimsSymbols, e = getNumIds(); i < e; ++i) { 1536 if (!memo[i] && !isColZero(/*pos=*/i)) 1537 noLocalRepVars.push_back(i - numDimsSymbols); 1538 } 1539 if (!noLocalRepVars.empty()) { 1540 LLVM_DEBUG({ 1541 llvm::dbgs() << "local variables at position(s) "; 1542 llvm::interleaveComma(noLocalRepVars, llvm::dbgs()); 1543 llvm::dbgs() << " do not have an explicit representation in:\n"; 1544 this->dump(); 1545 }); 1546 return IntegerSet(); 1547 } 1548 } 1549 1550 ArrayRef<AffineExpr> localExprs = 1551 ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds()); 1552 1553 // Construct the IntegerSet from the equalities/inequalities. 1554 unsigned numDims = getNumDimIds(); 1555 unsigned numSyms = getNumSymbolIds(); 1556 1557 SmallVector<bool, 16> eqFlags(getNumConstraints()); 1558 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true); 1559 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false); 1560 1561 SmallVector<AffineExpr, 8> exprs; 1562 exprs.reserve(getNumConstraints()); 1563 1564 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 1565 exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms, 1566 localExprs, context)); 1567 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 1568 exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims, 1569 numSyms, localExprs, context)); 1570 return IntegerSet::get(numDims, numSyms, exprs, eqFlags); 1571 } 1572 1573 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, 1574 ValueRange dims, ValueRange syms, 1575 SmallVector<Value> *newSyms) { 1576 assert(operands.size() == map.getNumInputs() && 1577 "expected same number of operands and map inputs"); 1578 MLIRContext *ctx = map.getContext(); 1579 Builder builder(ctx); 1580 SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {}); 1581 unsigned numSymbols = syms.size(); 1582 SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {}); 1583 if (newSyms) { 1584 newSyms->clear(); 1585 newSyms->append(syms.begin(), syms.end()); 1586 } 1587 1588 for (const auto &operand : llvm::enumerate(operands)) { 1589 // Compute replacement dim/sym of operand. 1590 AffineExpr replacement; 1591 auto dimIt = std::find(dims.begin(), dims.end(), operand.value()); 1592 auto symIt = std::find(syms.begin(), syms.end(), operand.value()); 1593 if (dimIt != dims.end()) { 1594 replacement = 1595 builder.getAffineDimExpr(std::distance(dims.begin(), dimIt)); 1596 } else if (symIt != syms.end()) { 1597 replacement = 1598 builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt)); 1599 } else { 1600 // This operand is neither a dimension nor a symbol. Add it as a new 1601 // symbol. 1602 replacement = builder.getAffineSymbolExpr(numSymbols++); 1603 if (newSyms) 1604 newSyms->push_back(operand.value()); 1605 } 1606 // Add to corresponding replacements vector. 1607 if (operand.index() < map.getNumDims()) { 1608 dimReplacements[operand.index()] = replacement; 1609 } else { 1610 symReplacements[operand.index() - map.getNumDims()] = replacement; 1611 } 1612 } 1613 1614 return map.replaceDimsAndSymbols(dimReplacements, symReplacements, 1615 dims.size(), numSymbols); 1616 } 1617 1618 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { 1619 FlatAffineValueConstraints domain = *this; 1620 // Convert all range variables to local variables. 1621 domain.convertDimToLocal(getNumDomainDims(), 1622 getNumDomainDims() + getNumRangeDims()); 1623 return domain; 1624 } 1625 1626 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { 1627 FlatAffineValueConstraints range = *this; 1628 // Convert all domain variables to local variables. 1629 range.convertDimToLocal(0, getNumDomainDims()); 1630 return range; 1631 } 1632 1633 void FlatAffineRelation::compose(const FlatAffineRelation &other) { 1634 assert(getNumDomainDims() == other.getNumRangeDims() && 1635 "Domain of this and range of other do not match"); 1636 assert(std::equal(values.begin(), values.begin() + getNumDomainDims(), 1637 other.values.begin() + other.getNumDomainDims()) && 1638 "Domain of this and range of other do not match"); 1639 1640 FlatAffineRelation rel = other; 1641 1642 // Convert `rel` from 1643 // [otherDomain] -> [otherRange] 1644 // to 1645 // [otherDomain] -> [otherRange thisRange] 1646 // and `this` from 1647 // [thisDomain] -> [thisRange] 1648 // to 1649 // [otherDomain thisDomain] -> [thisRange]. 1650 unsigned removeDims = rel.getNumRangeDims(); 1651 insertDomainId(0, rel.getNumDomainDims()); 1652 rel.appendRangeId(getNumRangeDims()); 1653 1654 // Merge symbol and local identifiers. 1655 mergeSymbolIds(rel); 1656 mergeLocalIds(rel); 1657 1658 // Convert `rel` from [otherDomain] -> [otherRange thisRange] to 1659 // [otherDomain] -> [thisRange] by converting first otherRange range ids 1660 // to local ids. 1661 rel.convertDimToLocal(rel.getNumDomainDims(), 1662 rel.getNumDomainDims() + removeDims); 1663 // Convert `this` from [otherDomain thisDomain] -> [thisRange] to 1664 // [otherDomain] -> [thisRange] by converting last thisDomain domain ids 1665 // to local ids. 1666 convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims()); 1667 1668 auto thisMaybeValues = getMaybeDimValues(); 1669 auto relMaybeValues = rel.getMaybeDimValues(); 1670 1671 // Add and match domain of `rel` to domain of `this`. 1672 for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) 1673 if (relMaybeValues[i].hasValue()) 1674 setValue(i, relMaybeValues[i].getValue()); 1675 // Add and match range of `this` to range of `rel`. 1676 for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) { 1677 unsigned rangeIdx = rel.getNumDomainDims() + i; 1678 if (thisMaybeValues[rangeIdx].hasValue()) 1679 rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue()); 1680 } 1681 1682 // Append `this` to `rel` and simplify constraints. 1683 rel.append(*this); 1684 rel.removeRedundantLocalVars(); 1685 1686 *this = rel; 1687 } 1688 1689 void FlatAffineRelation::inverse() { 1690 unsigned oldDomain = getNumDomainDims(); 1691 unsigned oldRange = getNumRangeDims(); 1692 // Add new range ids. 1693 appendRangeId(oldDomain); 1694 // Swap new ids with domain. 1695 for (unsigned i = 0; i < oldDomain; ++i) 1696 swapId(i, oldDomain + oldRange + i); 1697 // Remove the swapped domain. 1698 removeIdRange(0, oldDomain); 1699 // Set domain and range as inverse. 1700 numDomainDims = oldRange; 1701 numRangeDims = oldDomain; 1702 } 1703 1704 void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) { 1705 assert(pos <= getNumDomainDims() && 1706 "Id cannot be inserted at invalid position"); 1707 insertDimId(pos, num); 1708 numDomainDims += num; 1709 } 1710 1711 void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) { 1712 assert(pos <= getNumRangeDims() && 1713 "Id cannot be inserted at invalid position"); 1714 insertDimId(getNumDomainDims() + pos, num); 1715 numRangeDims += num; 1716 } 1717 1718 void FlatAffineRelation::appendDomainId(unsigned num) { 1719 insertDimId(getNumDomainDims(), num); 1720 numDomainDims += num; 1721 } 1722 1723 void FlatAffineRelation::appendRangeId(unsigned num) { 1724 insertDimId(getNumDimIds(), num); 1725 numRangeDims += num; 1726 } 1727 1728 void FlatAffineRelation::removeIdRange(IdKind kind, unsigned idStart, 1729 unsigned idLimit) { 1730 assert(idLimit <= getNumIdKind(kind)); 1731 if (idStart >= idLimit) 1732 return; 1733 1734 FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit); 1735 1736 // If kind is not SetDim, domain and range don't need to be updated. 1737 if (kind != IdKind::SetDim) 1738 return; 1739 1740 // Compute number of domain and range identifiers to remove. This is done by 1741 // intersecting the range of domain/range ids with range of ids to remove. 1742 unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims()); 1743 unsigned intersectDomainRHS = idStart; 1744 unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds()); 1745 unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims()); 1746 1747 if (intersectDomainLHS > intersectDomainRHS) 1748 numDomainDims -= intersectDomainLHS - intersectDomainRHS; 1749 if (intersectRangeLHS > intersectRangeRHS) 1750 numRangeDims -= intersectRangeLHS - intersectRangeRHS; 1751 } 1752 1753 LogicalResult mlir::getRelationFromMap(AffineMap &map, 1754 FlatAffineRelation &rel) { 1755 // Get flattened affine expressions. 1756 std::vector<SmallVector<int64_t, 8>> flatExprs; 1757 FlatAffineConstraints localVarCst; 1758 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) 1759 return failure(); 1760 1761 unsigned oldDimNum = localVarCst.getNumDimIds(); 1762 unsigned oldCols = localVarCst.getNumCols(); 1763 unsigned numRangeIds = map.getNumResults(); 1764 unsigned numDomainIds = map.getNumDims(); 1765 1766 // Add range as the new expressions. 1767 localVarCst.appendDimId(numRangeIds); 1768 1769 // Add equalities between source and range. 1770 SmallVector<int64_t, 8> eq(localVarCst.getNumCols()); 1771 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { 1772 // Zero fill. 1773 std::fill(eq.begin(), eq.end(), 0); 1774 // Fill equality. 1775 for (unsigned j = 0, f = oldDimNum; j < f; ++j) 1776 eq[j] = flatExprs[i][j]; 1777 for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) 1778 eq[j + numRangeIds] = flatExprs[i][j]; 1779 // Set this dimension to -1 to equate lhs and rhs and add equality. 1780 eq[numDomainIds + i] = -1; 1781 localVarCst.addEquality(eq); 1782 } 1783 1784 // Create relation and return success. 1785 rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst); 1786 return success(); 1787 } 1788 1789 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map, 1790 FlatAffineRelation &rel) { 1791 1792 AffineMap affineMap = map.getAffineMap(); 1793 if (failed(getRelationFromMap(affineMap, rel))) 1794 return failure(); 1795 1796 // Set symbol values for domain dimensions and symbols. 1797 for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) 1798 rel.setValue(i, map.getOperand(i)); 1799 for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e; 1800 ++i) 1801 rel.setValue(i, map.getOperand(i - rel.getNumRangeDims())); 1802 1803 return success(); 1804 } 1805