1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Structures for affine/polyhedral analysis of affine dialect ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 14 #include "mlir/Analysis/Presburger/LinearTransform.h" 15 #include "mlir/Analysis/Presburger/Simplex.h" 16 #include "mlir/Analysis/Presburger/Utils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 20 #include "mlir/IR/AffineExprVisitor.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Support/MathExtras.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallPtrSet.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #define DEBUG_TYPE "affine-structures" 31 32 using namespace mlir; 33 using namespace presburger; 34 35 namespace { 36 37 // See comments for SimpleAffineExprFlattener. 38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording 39 // constraint information associated with mod's, floordiv's, and ceildiv's 40 // in FlatAffineValueConstraints 'localVarCst'. 41 struct AffineExprFlattener : public SimpleAffineExprFlattener { 42 public: 43 // Constraints connecting newly introduced local variables (for mod's and 44 // div's) to existing (dimensional and symbolic) ones. These are always 45 // inequalities. 
46 IntegerPolyhedron localVarCst; 47 48 AffineExprFlattener(unsigned nDims, unsigned nSymbols) 49 : SimpleAffineExprFlattener(nDims, nSymbols), 50 localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {} 51 52 private: 53 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 54 // The local identifier added is always a floordiv of a pure add/mul affine 55 // function of other identifiers, coefficients of which are specified in 56 // `dividend' and with respect to the positive constant `divisor'. localExpr 57 // is the simplified tree expression (AffineExpr) corresponding to the 58 // quantifier. 59 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 60 AffineExpr localExpr) override { 61 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 62 // Update localVarCst. 63 localVarCst.addLocalFloorDiv(dividend, divisor); 64 } 65 }; 66 67 } // namespace 68 69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be 70 // flattened (i.e., semi-affine expressions not handled yet). 71 static LogicalResult 72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, 73 unsigned numSymbols, 74 std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 75 FlatAffineValueConstraints *localVarCst) { 76 if (exprs.empty()) { 77 localVarCst->reset(numDims, numSymbols); 78 return success(); 79 } 80 81 AffineExprFlattener flattener(numDims, numSymbols); 82 // Use the same flattener to simplify each expression successively. This way 83 // local identifiers / expressions are shared. 
84 for (auto expr : exprs) { 85 if (!expr.isPureAffine()) 86 return failure(); 87 88 flattener.walkPostOrder(expr); 89 } 90 91 assert(flattener.operandExprStack.size() == exprs.size()); 92 flattenedExprs->clear(); 93 flattenedExprs->assign(flattener.operandExprStack.begin(), 94 flattener.operandExprStack.end()); 95 96 if (localVarCst) 97 localVarCst->clearAndCopyFrom(flattener.localVarCst); 98 99 return success(); 100 } 101 102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 103 // be flattened (semi-affine expressions not handled yet). 104 LogicalResult 105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, 106 unsigned numSymbols, 107 SmallVectorImpl<int64_t> *flattenedExpr, 108 FlatAffineValueConstraints *localVarCst) { 109 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 110 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, 111 &flattenedExprs, localVarCst); 112 *flattenedExpr = flattenedExprs[0]; 113 return ret; 114 } 115 116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 117 /// flattened (i.e., semi-affine expressions not handled yet). 
118 LogicalResult mlir::getFlattenedAffineExprs( 119 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 120 FlatAffineValueConstraints *localVarCst) { 121 if (map.getNumResults() == 0) { 122 localVarCst->reset(map.getNumDims(), map.getNumSymbols()); 123 return success(); 124 } 125 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), 126 map.getNumSymbols(), flattenedExprs, 127 localVarCst); 128 } 129 130 LogicalResult mlir::getFlattenedAffineExprs( 131 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 132 FlatAffineValueConstraints *localVarCst) { 133 if (set.getNumConstraints() == 0) { 134 localVarCst->reset(set.getNumDims(), set.getNumSymbols()); 135 return success(); 136 } 137 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 138 set.getNumSymbols(), flattenedExprs, 139 localVarCst); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // FlatAffineConstraints / FlatAffineValueConstraints. 144 //===----------------------------------------------------------------------===// 145 146 std::unique_ptr<FlatAffineValueConstraints> 147 FlatAffineValueConstraints::clone() const { 148 return std::make_unique<FlatAffineValueConstraints>(*this); 149 } 150 151 // Construct from an IntegerSet. 152 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set) 153 : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(), 154 set.getNumDims() + set.getNumSymbols() + 1, 155 PresburgerSpace::getSetSpace(set.getNumDims(), 156 set.getNumSymbols(), 157 /*numLocals=*/0)) { 158 159 // Resize values. 160 values.resize(getNumDimAndSymbolIds(), None); 161 162 // Flatten expressions and add them to the constraint system. 
163 std::vector<SmallVector<int64_t, 8>> flatExprs; 164 FlatAffineValueConstraints localVarCst; 165 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 166 assert(false && "flattening unimplemented for semi-affine integer sets"); 167 return; 168 } 169 assert(flatExprs.size() == set.getNumConstraints()); 170 insertId(IdKind::Local, getNumIdKind(IdKind::Local), 171 /*num=*/localVarCst.getNumLocalIds()); 172 173 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 174 const auto &flatExpr = flatExprs[i]; 175 assert(flatExpr.size() == getNumCols()); 176 if (set.getEqFlags()[i]) { 177 addEquality(flatExpr); 178 } else { 179 addInequality(flatExpr); 180 } 181 } 182 // Add the other constraints involving local id's from flattening. 183 append(localVarCst); 184 } 185 186 // Construct a hyperrectangular constraint set from ValueRanges that represent 187 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are 188 // expected to match one to one. The order of variables and constraints is: 189 // 190 // ivs | lbs | ubs | eq/ineq 191 // ----+-----+-----+--------- 192 // 1 -1 0 >= 0 193 // ----+-----+-----+--------- 194 // -1 0 1 >= 0 195 // 196 // All dimensions as set as DimId. 
197 FlatAffineValueConstraints 198 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs, 199 ValueRange ubs) { 200 FlatAffineValueConstraints res; 201 unsigned nIvs = ivs.size(); 202 assert(nIvs == lbs.size() && "expected as many lower bounds as ivs"); 203 assert(nIvs == ubs.size() && "expected as many upper bounds as ivs"); 204 205 if (nIvs == 0) 206 return res; 207 208 res.appendDimId(ivs); 209 unsigned lbsStart = res.appendDimId(lbs); 210 unsigned ubsStart = res.appendDimId(ubs); 211 212 MLIRContext *ctx = ivs.front().getContext(); 213 for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) { 214 // iv - lb >= 0 215 AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 216 getAffineDimExpr(lbsStart + ivIdx, ctx)); 217 if (failed(res.addBound(BoundType::LB, ivIdx, lb))) 218 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 219 // -iv + ub >= 0 220 AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0, 221 getAffineDimExpr(ubsStart + ivIdx, ctx)); 222 if (failed(res.addBound(BoundType::UB, ivIdx, ub))) 223 llvm_unreachable("Unexpected FlatAffineValueConstraints creation error"); 224 } 225 return res; 226 } 227 228 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities, 229 unsigned numReservedEqualities, 230 unsigned newNumReservedCols, 231 unsigned newNumDims, 232 unsigned newNumSymbols, 233 unsigned newNumLocals) { 234 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 235 "minimum 1 column"); 236 *this = FlatAffineValueConstraints(numReservedInequalities, 237 numReservedEqualities, newNumReservedCols, 238 newNumDims, newNumSymbols, newNumLocals); 239 } 240 241 void FlatAffineValueConstraints::reset(unsigned newNumDims, 242 unsigned newNumSymbols, 243 unsigned newNumLocals) { 244 reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, 245 /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1, 246 newNumDims, 
newNumSymbols, newNumLocals); 247 } 248 249 void FlatAffineValueConstraints::reset( 250 unsigned numReservedInequalities, unsigned numReservedEqualities, 251 unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, 252 unsigned newNumLocals, ArrayRef<Value> valArgs) { 253 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 254 "minimum 1 column"); 255 SmallVector<Optional<Value>, 8> newVals; 256 if (!valArgs.empty()) 257 newVals.assign(valArgs.begin(), valArgs.end()); 258 259 *this = FlatAffineValueConstraints( 260 numReservedInequalities, numReservedEqualities, newNumReservedCols, 261 newNumDims, newNumSymbols, newNumLocals, newVals); 262 } 263 264 void FlatAffineValueConstraints::reset(unsigned newNumDims, 265 unsigned newNumSymbols, 266 unsigned newNumLocals, 267 ArrayRef<Value> valArgs) { 268 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, 269 newNumSymbols, newNumLocals, valArgs); 270 } 271 272 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) { 273 unsigned pos = getNumDimIds(); 274 insertId(IdKind::SetDim, pos, vals); 275 return pos; 276 } 277 278 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) { 279 unsigned pos = getNumSymbolIds(); 280 insertId(IdKind::Symbol, pos, vals); 281 return pos; 282 } 283 284 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos, 285 ValueRange vals) { 286 return insertId(IdKind::SetDim, pos, vals); 287 } 288 289 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos, 290 ValueRange vals) { 291 return insertId(IdKind::Symbol, pos, vals); 292 } 293 294 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos, 295 unsigned num) { 296 unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num); 297 298 if (kind != IdKind::Local) { 299 values.insert(values.begin() + absolutePos, num, None); 300 assert(values.size() == getNumDimAndSymbolIds()); 301 } 302 303 return absolutePos; 304 } 305 306 
unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos, 307 ValueRange vals) { 308 assert(!vals.empty() && "expected ValueRange with Values."); 309 assert(kind != IdKind::Local && 310 "values cannot be attached to local identifiers."); 311 unsigned num = vals.size(); 312 unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num); 313 314 // If a Value is provided, insert it; otherwise use None. 315 for (unsigned i = 0; i < num; ++i) 316 values.insert(values.begin() + absolutePos + i, 317 vals[i] ? Optional<Value>(vals[i]) : None); 318 319 assert(values.size() == getNumDimAndSymbolIds()); 320 return absolutePos; 321 } 322 323 bool FlatAffineValueConstraints::hasValues() const { 324 return llvm::find_if(values, [](Optional<Value> id) { 325 return id.hasValue(); 326 }) != values.end(); 327 } 328 329 /// Checks if two constraint systems are in the same space, i.e., if they are 330 /// associated with the same set of identifiers, appearing in the same order. 331 static bool areIdsAligned(const FlatAffineValueConstraints &a, 332 const FlatAffineValueConstraints &b) { 333 return a.getNumDimIds() == b.getNumDimIds() && 334 a.getNumSymbolIds() == b.getNumSymbolIds() && 335 a.getNumIds() == b.getNumIds() && 336 a.getMaybeValues().equals(b.getMaybeValues()); 337 } 338 339 /// Calls areIdsAligned to check if two constraint systems have the same set 340 /// of identifiers in the same order. 341 bool FlatAffineValueConstraints::areIdsAlignedWithOther( 342 const FlatAffineValueConstraints &other) { 343 return areIdsAligned(*this, other); 344 } 345 346 /// Checks if the SSA values associated with `cst`'s identifiers in range 347 /// [start, end) are unique. 
348 static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique( 349 const FlatAffineValueConstraints &cst, unsigned start, unsigned end) { 350 351 assert(start <= cst.getNumDimAndSymbolIds() && 352 "Start position out of bounds"); 353 assert(end <= cst.getNumDimAndSymbolIds() && "End position out of bounds"); 354 355 if (start >= end) 356 return true; 357 358 SmallPtrSet<Value, 8> uniqueIds; 359 ArrayRef<Optional<Value>> maybeValues = 360 cst.getMaybeValues().slice(start, end - start); 361 for (Optional<Value> val : maybeValues) { 362 if (val.hasValue() && !uniqueIds.insert(val.getValue()).second) 363 return false; 364 } 365 return true; 366 } 367 368 /// Checks if the SSA values associated with `cst`'s identifiers are unique. 369 static bool LLVM_ATTRIBUTE_UNUSED 370 areIdsUnique(const FlatAffineValueConstraints &cst) { 371 return areIdsUnique(cst, 0, cst.getNumDimAndSymbolIds()); 372 } 373 374 /// Checks if the SSA values associated with `cst`'s identifiers of kind `kind` 375 /// are unique. 376 static bool LLVM_ATTRIBUTE_UNUSED 377 areIdsUnique(const FlatAffineValueConstraints &cst, IdKind kind) { 378 379 if (kind == IdKind::SetDim) 380 return areIdsUnique(cst, 0, cst.getNumDimIds()); 381 if (kind == IdKind::Symbol) 382 return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds()); 383 llvm_unreachable("Unexpected IdKind"); 384 } 385 386 /// Merge and align the identifiers of A and B starting at 'offset', so that 387 /// both constraint systems get the union of the contained identifiers that is 388 /// dimension-wise and symbol-wise unique; both constraint systems are updated 389 /// so that they have the union of all identifiers, with A's original 390 /// identifiers appearing first followed by any of B's identifiers that didn't 391 /// appear in A. Local identifiers in B that have the same division 392 /// representation as local identifiers in A are merged into one. 
// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
//       Output: both A, B have (%i, %j, %k) [%M, %N, %P]
//
// Dims before `offset` are left untouched in both systems. Both `a` and `b`
// are modified in place.
static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
                             FlatAffineValueConstraints *b) {
  assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
  // A merge/align isn't meaningful if a cst's ids aren't distinct.
  assert(areIdsUnique(*a) && "A's values aren't unique");
  assert(areIdsUnique(*b) && "B's values aren't unique");

  // Alignment below matches identifiers by their attached Value (findId), so
  // every identifier from `offset` onward is required to have one.
  assert(std::all_of(a->getMaybeValues().begin() + offset,
                     a->getMaybeValues().end(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  assert(std::all_of(b->getMaybeValues().begin() + offset,
                     b->getMaybeValues().end(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimIds(), &aDimValues);

  {
    // Merge dims from A into B.
    unsigned d = offset;
    for (auto aDimValue : aDimValues) {
      unsigned loc;
      if (b->findId(aDimValue, &loc)) {
        // B already has this Value: move it into the same position it has
        // in A.
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimIds() &&
               "A's dim appears in B's non-dim position");
        b->swapId(d, loc);
      } else {
        // B lacks this Value: insert a new dim for it at the matching
        // position.
        b->insertDimId(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
      a->appendDimId(b->getValue(t));
    }
    assert(a->getNumDimIds() == b->getNumDimIds() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolIds(*b);
  // Merge and align local ids of A and B
  a->mergeLocalIds(*b);

  assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
}

// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
445 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther( 446 unsigned offset, FlatAffineValueConstraints *other) { 447 mergeAndAlignIds(offset, this, other); 448 } 449 450 LogicalResult 451 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) { 452 return composeMatchingMap( 453 computeAlignedMap(vMap->getAffineMap(), vMap->getOperands())); 454 } 455 456 // Similar to `composeMap` except that no Values need be associated with the 457 // constraint system nor are they looked at -- the dimensions and symbols of 458 // `other` are expected to correspond 1:1 to `this` system. 459 LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) { 460 assert(other.getNumDims() == getNumDimIds() && "dim mismatch"); 461 assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); 462 463 std::vector<SmallVector<int64_t, 8>> flatExprs; 464 if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs))) 465 return failure(); 466 assert(flatExprs.size() == other.getNumResults()); 467 468 // Add dimensions corresponding to the map's results. 469 insertDimId(/*pos=*/0, /*num=*/other.getNumResults()); 470 471 // We add one equality for each result connecting the result dim of the map to 472 // the other identifiers. 473 // E.g.: if the expression is 16*i0 + i1, and this is the r^th 474 // iteration/result of the value map, we are adding the equality: 475 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we 476 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. 477 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { 478 const auto &flatExpr = flatExprs[r]; 479 assert(flatExpr.size() >= other.getNumInputs() + 1); 480 481 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); 482 // Set the coefficient for this result to one. 483 eqToAdd[r] = 1; 484 485 // Dims and symbols. 
486 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { 487 // Negate `eq[r]` since the newly added dimension will be set to this one. 488 eqToAdd[e + i] = -flatExpr[i]; 489 } 490 // Local columns of `eq` are at the beginning. 491 unsigned j = getNumDimIds() + getNumSymbolIds(); 492 unsigned end = flatExpr.size() - 1; 493 for (unsigned i = other.getNumInputs(); i < end; i++, j++) { 494 eqToAdd[j] = -flatExpr[i]; 495 } 496 497 // Constant term. 498 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; 499 500 // Add the equality connecting the result of the map to this constraint set. 501 addEquality(eqToAdd); 502 } 503 504 return success(); 505 } 506 507 // Turn a symbol into a dimension. 508 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) { 509 unsigned pos; 510 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && 511 pos < cst->getNumDimAndSymbolIds()) { 512 cst->swapId(pos, cst->getNumDimIds()); 513 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1); 514 } 515 } 516 517 /// Merge and align symbols of `this` and `other` such that both get union of 518 /// of symbols that are unique. Symbols in `this` and `other` should be 519 /// unique. Symbols with Value as `None` are considered to be inequal to all 520 /// other symbols. 521 void FlatAffineValueConstraints::mergeSymbolIds( 522 FlatAffineValueConstraints &other) { 523 524 assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); 525 assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); 526 527 SmallVector<Value, 4> aSymValues; 528 getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues); 529 530 // Merge symbols: merge symbols into `other` first from `this`. 
531 unsigned s = other.getNumDimIds(); 532 for (Value aSymValue : aSymValues) { 533 unsigned loc; 534 // If the id is a symbol in `other`, then align it, otherwise assume that 535 // it is a new symbol 536 if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() && 537 loc < other.getNumDimAndSymbolIds()) 538 other.swapId(s, loc); 539 else 540 other.insertSymbolId(s - other.getNumDimIds(), aSymValue); 541 s++; 542 } 543 544 // Symbols that are in other, but not in this, are added at the end. 545 for (unsigned t = other.getNumDimIds() + getNumSymbolIds(), 546 e = other.getNumDimAndSymbolIds(); 547 t < e; t++) 548 insertSymbolId(getNumSymbolIds(), other.getValue(t)); 549 550 assert(getNumSymbolIds() == other.getNumSymbolIds() && 551 "expected same number of symbols"); 552 assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); 553 assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); 554 } 555 556 // Changes all symbol identifiers which are loop IVs to dim identifiers. 557 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() { 558 // Gather all symbols which are loop IVs. 559 SmallVector<Value, 4> loopIVs; 560 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { 561 if (hasValue(i) && getForInductionVarOwner(getValue(i))) 562 loopIVs.push_back(getValue(i)); 563 } 564 // Turn each symbol in 'loopIVs' into a dim identifier. 565 for (auto iv : loopIVs) { 566 turnSymbolIntoDim(this, iv); 567 } 568 } 569 570 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { 571 if (containsId(val)) 572 return; 573 574 // Caller is expected to fully compose map/operands if necessary. 575 assert((isTopLevelValue(val) || isForInductionVar(val)) && 576 "non-terminal symbol / loop IV expected"); 577 // Outer loop IVs could be used in forOp's bounds. 
578 if (auto loop = getForInductionVarOwner(val)) { 579 appendDimId(val); 580 if (failed(this->addAffineForOpDomain(loop))) 581 LLVM_DEBUG( 582 loop.emitWarning("failed to add domain info to constraint system")); 583 return; 584 } 585 // Add top level symbol. 586 appendSymbolId(val); 587 // Check if the symbol is a constant. 588 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) 589 addBound(BoundType::EQ, val, constOp.value()); 590 } 591 592 LogicalResult 593 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { 594 unsigned pos; 595 // Pre-condition for this method. 596 if (!findId(forOp.getInductionVar(), &pos)) { 597 assert(false && "Value not found"); 598 return failure(); 599 } 600 601 int64_t step = forOp.getStep(); 602 if (step != 1) { 603 if (!forOp.hasConstantLowerBound()) 604 LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated")); 605 else { 606 // Add constraints for the stride. 607 // (iv - lb) % step = 0 can be written as: 608 // (iv - lb) - step * q = 0 where q = (iv - lb) / step. 609 // Add local variable 'q' and add the above equality. 610 // The first constraint is q = (iv - lb) floordiv step 611 SmallVector<int64_t, 8> dividend(getNumCols(), 0); 612 int64_t lb = forOp.getConstantLowerBound(); 613 dividend[pos] = 1; 614 dividend.back() -= lb; 615 addLocalFloorDiv(dividend, step); 616 // Second constraint: (iv - lb) - step * q = 0. 617 SmallVector<int64_t, 8> eq(getNumCols(), 0); 618 eq[pos] = 1; 619 eq.back() -= lb; 620 // For the local var just added above. 621 eq[getNumCols() - 2] = -step; 622 addEquality(eq); 623 } 624 } 625 626 if (forOp.hasConstantLowerBound()) { 627 addBound(BoundType::LB, pos, forOp.getConstantLowerBound()); 628 } else { 629 // Non-constant lower bound case. 
630 if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(), 631 forOp.getLowerBoundOperands()))) 632 return failure(); 633 } 634 635 if (forOp.hasConstantUpperBound()) { 636 addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1); 637 return success(); 638 } 639 // Non-constant upper bound case. 640 return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(), 641 forOp.getUpperBoundOperands()); 642 } 643 644 LogicalResult 645 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, 646 ArrayRef<AffineMap> ubMaps, 647 ArrayRef<Value> operands) { 648 assert(lbMaps.size() == ubMaps.size()); 649 assert(lbMaps.size() <= getNumDimIds()); 650 651 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { 652 AffineMap lbMap = lbMaps[i]; 653 AffineMap ubMap = ubMaps[i]; 654 assert(!lbMap || lbMap.getNumInputs() == operands.size()); 655 assert(!ubMap || ubMap.getNumInputs() == operands.size()); 656 657 // Check if this slice is just an equality along this dimension. If so, 658 // retrieve the existing loop it equates to and add it to the system. 659 if (lbMap && ubMap && lbMap.getNumResults() == 1 && 660 ubMap.getNumResults() == 1 && 661 lbMap.getResult(0) + 1 == ubMap.getResult(0) && 662 // The condition above will be true for maps describing a single 663 // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). 664 // Make sure we skip those cases by checking that the lb result is not 665 // just a constant. 666 !lbMap.getResult(0).isa<AffineConstantExpr>()) { 667 // Limited support: we expect the lb result to be just a loop dimension. 668 // Not supported otherwise for now. 
669 AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>(); 670 if (!result) 671 return failure(); 672 673 AffineForOp loop = 674 getForInductionVarOwner(operands[result.getPosition()]); 675 if (!loop) 676 return failure(); 677 678 if (failed(addAffineForOpDomain(loop))) 679 return failure(); 680 continue; 681 } 682 683 // This slice refers to a loop that doesn't exist in the IR yet. Add its 684 // bounds to the system assuming its dimension identifier position is the 685 // same as the position of the loop in the loop nest. 686 if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands))) 687 return failure(); 688 if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands))) 689 return failure(); 690 } 691 return success(); 692 } 693 694 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { 695 // Create the base constraints from the integer set attached to ifOp. 696 FlatAffineValueConstraints cst(ifOp.getIntegerSet()); 697 698 // Bind ids in the constraints to ifOp operands. 699 SmallVector<Value, 4> operands = ifOp.getOperands(); 700 cst.setValues(0, cst.getNumDimAndSymbolIds(), operands); 701 702 // Merge the constraints from ifOp to the current domain. We need first merge 703 // and align the IDs from both constraints, and then append the constraints 704 // from the ifOp into the current one. 
705 mergeAndAlignIdsWithOther(0, &cst); 706 append(cst); 707 } 708 709 bool FlatAffineValueConstraints::hasConsistentState() const { 710 return IntegerPolyhedron::hasConsistentState() && 711 values.size() == getNumDimAndSymbolIds(); 712 } 713 714 void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart, 715 unsigned idLimit) { 716 IntegerPolyhedron::removeIdRange(kind, idStart, idLimit); 717 unsigned offset = getIdKindOffset(kind); 718 719 if (kind != IdKind::Local) { 720 values.erase(values.begin() + idStart + offset, 721 values.begin() + idLimit + offset); 722 } 723 } 724 725 // Determine whether the identifier at 'pos' (say id_r) can be expressed as 726 // modulo of another known identifier (say id_n) w.r.t a constant. For example, 727 // if the following constraints hold true: 728 // ``` 729 // 0 <= id_r <= divisor - 1 730 // id_n - (divisor * q_expr) = id_r 731 // ``` 732 // where `id_n` is a known identifier (called dividend), and `q_expr` is an 733 // `AffineExpr` (called the quotient expression), `id_r` can be written as: 734 // 735 // `id_r = id_n mod divisor`. 736 // 737 // Additionally, in a special case of the above constaints where `q_expr` is an 738 // identifier itself that is not yet known (say `id_q`), it can be written as a 739 // floordiv in the following way: 740 // 741 // `id_q = id_n floordiv divisor`. 742 // 743 // Returns true if the above mod or floordiv are detected, updating 'memo' with 744 // these new expressions. Returns false otherwise. 745 static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos, 746 int64_t lbConst, int64_t ubConst, 747 SmallVectorImpl<AffineExpr> &memo, 748 MLIRContext *context) { 749 assert(pos < cst.getNumIds() && "invalid position"); 750 751 // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can 752 // be determined. 
753 if (lbConst != 0 || ubConst < 1) 754 return false; 755 int64_t divisor = ubConst + 1; 756 757 // Check for the aforementioned conditions in each equality. 758 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); 759 curEquality < numEqualities; curEquality++) { 760 int64_t coefficientAtPos = cst.atEq(curEquality, pos); 761 // If current equality does not involve `id_r`, continue to the next 762 // equality. 763 if (coefficientAtPos == 0) 764 continue; 765 766 // Constant term should be 0 in this equality. 767 if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0) 768 continue; 769 770 // Traverse through the equality and construct the dividend expression 771 // `dividendExpr`, to contain all the identifiers which are known and are 772 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the 773 // `dividendExpr` gets simplified into a single identifier `id_n` discussed 774 // above. 775 auto dividendExpr = getAffineConstantExpr(0, context); 776 777 // Track the terms that go into quotient expression, later used to detect 778 // additional floordiv. 779 unsigned quotientCount = 0; 780 int quotientPosition = -1; 781 int quotientSign = 1; 782 783 // Consider each term in the current equality. 784 unsigned curId, e; 785 for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) { 786 // Ignore id_r. 787 if (curId == pos) 788 continue; 789 int64_t coefficientOfCurId = cst.atEq(curEquality, curId); 790 // Ignore ids that do not contribute to the current equality. 791 if (coefficientOfCurId == 0) 792 continue; 793 // Check if the current id goes into the quotient expression. 794 if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) { 795 quotientCount++; 796 quotientPosition = curId; 797 quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1; 798 continue; 799 } 800 // Identifiers that are part of dividendExpr should be known. 
801 if (!memo[curId]) 802 break; 803 // Append the current identifier to the dividend expression. 804 dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId; 805 } 806 807 // Can't construct expression as it depends on a yet uncomputed id. 808 if (curId < e) 809 continue; 810 811 // Express `id_r` in terms of the other ids collected so far. 812 if (coefficientAtPos > 0) 813 dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos); 814 else 815 dividendExpr = dividendExpr.floorDiv(-coefficientAtPos); 816 817 // Simplify the expression. 818 dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(), 819 cst.getNumSymbolIds()); 820 // Only if the final dividend expression is just a single id (which we call 821 // `id_n`), we can proceed. 822 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it 823 // to dims themselves. 824 auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>(); 825 if (!dimExpr) 826 continue; 827 828 // Express `id_r` as `id_n % divisor` and store the expression in `memo`. 829 if (quotientCount >= 1) { 830 auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB, 831 dimExpr.getPosition()); 832 // If `id_n` has an upperbound that is less than the divisor, mod can be 833 // eliminated altogether. 834 if (ub.hasValue() && ub.getValue() < divisor) 835 memo[pos] = dimExpr; 836 else 837 memo[pos] = dimExpr % divisor; 838 // If a unique quotient `id_q` was seen, it can be expressed as 839 // `id_n floordiv divisor`. 840 if (quotientCount == 1 && !memo[quotientPosition]) 841 memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign; 842 843 return true; 844 } 845 } 846 return false; 847 } 848 849 /// Check if the pos^th identifier can be expressed as a floordiv of an affine 850 /// function of other identifiers (where the divisor is a positive constant) 851 /// given the initial set of expressions in `exprs`. 
If it can be, the 852 /// corresponding position in `exprs` is set as the detected affine expr. For 853 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can 854 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 855 /// <= i <= 32q + 31 => q = i floordiv 32. 856 static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst, 857 unsigned pos, MLIRContext *context, 858 SmallVectorImpl<AffineExpr> &exprs) { 859 assert(pos < cst.getNumIds() && "invalid position"); 860 861 // Get upper-lower bound pair for this variable. 862 SmallVector<bool, 8> foundRepr(cst.getNumIds(), false); 863 for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i) 864 if (exprs[i]) 865 foundRepr[i] = true; 866 867 SmallVector<int64_t, 8> dividend; 868 unsigned divisor; 869 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); 870 871 // No upper-lower bound pair found for this var. 872 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) 873 return false; 874 875 // Construct the dividend expression. 876 auto dividendExpr = getAffineConstantExpr(dividend.back(), context); 877 for (unsigned c = 0, f = cst.getNumIds(); c < f; c++) 878 if (dividend[c] != 0) 879 dividendExpr = dividendExpr + dividend[c] * exprs[c]; 880 881 // Successfully detected the floordiv. 
882 exprs[pos] = dividendExpr.floorDiv(divisor); 883 return true; 884 } 885 886 std::pair<AffineMap, AffineMap> 887 FlatAffineValueConstraints::getLowerAndUpperBound( 888 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, 889 ArrayRef<AffineExpr> localExprs, MLIRContext *context) const { 890 assert(pos + offset < getNumDimIds() && "invalid dim start pos"); 891 assert(symStartPos >= (pos + offset) && "invalid sym start pos"); 892 assert(getNumLocalIds() == localExprs.size() && 893 "incorrect local exprs count"); 894 895 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; 896 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices, 897 offset, num); 898 899 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). 900 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { 901 b.clear(); 902 for (unsigned i = 0, e = a.size(); i < e; ++i) { 903 if (i < offset || i >= offset + num) 904 b.push_back(a[i]); 905 } 906 }; 907 908 SmallVector<int64_t, 8> lb, ub; 909 SmallVector<AffineExpr, 4> lbExprs; 910 unsigned dimCount = symStartPos - num; 911 unsigned symCount = getNumDimAndSymbolIds() - symStartPos; 912 lbExprs.reserve(lbIndices.size() + eqIndices.size()); 913 // Lower bound expressions. 914 for (auto idx : lbIndices) { 915 auto ineq = getInequality(idx); 916 // Extract the lower bound (in terms of other coeff's + const), i.e., if 917 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j 918 // - 1. 
919 addCoeffs(ineq, lb); 920 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); 921 auto expr = 922 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context); 923 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor 924 int64_t divisor = std::abs(ineq[pos + offset]); 925 expr = (expr + divisor - 1).floorDiv(divisor); 926 lbExprs.push_back(expr); 927 } 928 929 SmallVector<AffineExpr, 4> ubExprs; 930 ubExprs.reserve(ubIndices.size() + eqIndices.size()); 931 // Upper bound expressions. 932 for (auto idx : ubIndices) { 933 auto ineq = getInequality(idx); 934 // Extract the upper bound (in terms of other coeff's + const). 935 addCoeffs(ineq, ub); 936 auto expr = 937 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context); 938 expr = expr.floorDiv(std::abs(ineq[pos + offset])); 939 // Upper bound is exclusive. 940 ubExprs.push_back(expr + 1); 941 } 942 943 // Equalities. It's both a lower and a upper bound. 944 SmallVector<int64_t, 4> b; 945 for (auto idx : eqIndices) { 946 auto eq = getEquality(idx); 947 addCoeffs(eq, b); 948 if (eq[pos + offset] > 0) 949 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>()); 950 951 // Extract the upper bound (in terms of other coeff's + const). 952 auto expr = 953 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 954 expr = expr.floorDiv(std::abs(eq[pos + offset])); 955 // Upper bound is exclusive. 956 ubExprs.push_back(expr + 1); 957 // Lower bound. 
958 expr = 959 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 960 expr = expr.ceilDiv(std::abs(eq[pos + offset])); 961 lbExprs.push_back(expr); 962 } 963 964 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context); 965 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context); 966 967 return {lbMap, ubMap}; 968 } 969 970 /// Computes the lower and upper bounds of the first 'num' dimensional 971 /// identifiers (starting at 'offset') as affine maps of the remaining 972 /// identifiers (dimensional and symbolic identifiers). Local identifiers are 973 /// themselves explicitly computed as affine functions of other identifiers in 974 /// this process if needed. 975 void FlatAffineValueConstraints::getSliceBounds( 976 unsigned offset, unsigned num, MLIRContext *context, 977 SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps, 978 bool getClosedUB) { 979 assert(num < getNumDimIds() && "invalid range"); 980 981 // Basic simplification. 982 normalizeConstraintsByGCD(); 983 984 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num 985 << " identifiers\n"); 986 LLVM_DEBUG(dump()); 987 988 // Record computed/detected identifiers. 989 SmallVector<AffineExpr, 8> memo(getNumIds()); 990 // Initialize dimensional and symbolic identifiers. 991 for (unsigned i = 0, e = getNumDimIds(); i < e; i++) { 992 if (i < offset) 993 memo[i] = getAffineDimExpr(i, context); 994 else if (i >= offset + num) 995 memo[i] = getAffineDimExpr(i - num, context); 996 } 997 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) 998 memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); 999 1000 bool changed; 1001 do { 1002 changed = false; 1003 // Identify yet unknown identifiers as constants or mod's / floordiv's of 1004 // other identifiers if possible. 
1005 for (unsigned pos = 0; pos < getNumIds(); pos++) { 1006 if (memo[pos]) 1007 continue; 1008 1009 auto lbConst = getConstantBound(BoundType::LB, pos); 1010 auto ubConst = getConstantBound(BoundType::UB, pos); 1011 if (lbConst.hasValue() && ubConst.hasValue()) { 1012 // Detect equality to a constant. 1013 if (lbConst.getValue() == ubConst.getValue()) { 1014 memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); 1015 changed = true; 1016 continue; 1017 } 1018 1019 // Detect an identifier as modulo of another identifier w.r.t a 1020 // constant. 1021 if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), 1022 memo, context)) { 1023 changed = true; 1024 continue; 1025 } 1026 } 1027 1028 // Detect an identifier as a floordiv of an affine function of other 1029 // identifiers (divisor is a positive constant). 1030 if (detectAsFloorDiv(*this, pos, context, memo)) { 1031 changed = true; 1032 continue; 1033 } 1034 1035 // Detect an identifier as an expression of other identifiers. 1036 unsigned idx; 1037 if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) { 1038 continue; 1039 } 1040 1041 // Build AffineExpr solving for identifier 'pos' in terms of all others. 1042 auto expr = getAffineConstantExpr(0, context); 1043 unsigned j, e; 1044 for (j = 0, e = getNumIds(); j < e; ++j) { 1045 if (j == pos) 1046 continue; 1047 int64_t c = atEq(idx, j); 1048 if (c == 0) 1049 continue; 1050 // If any of the involved IDs hasn't been found yet, we can't proceed. 1051 if (!memo[j]) 1052 break; 1053 expr = expr + memo[j] * c; 1054 } 1055 if (j < e) 1056 // Can't construct expression as it depends on a yet uncomputed 1057 // identifier. 1058 continue; 1059 1060 // Add constant term to AffineExpr. 1061 expr = expr + atEq(idx, getNumIds()); 1062 int64_t vPos = atEq(idx, pos); 1063 assert(vPos != 0 && "expected non-zero here"); 1064 if (vPos > 0) 1065 expr = (-expr).floorDiv(vPos); 1066 else 1067 // vPos < 0. 
1068 expr = expr.floorDiv(-vPos); 1069 // Successfully constructed expression. 1070 memo[pos] = expr; 1071 changed = true; 1072 } 1073 // This loop is guaranteed to reach a fixed point - since once an 1074 // identifier's explicit form is computed (in memo[pos]), it's not updated 1075 // again. 1076 } while (changed); 1077 1078 int64_t ubAdjustment = getClosedUB ? 0 : 1; 1079 1080 // Set the lower and upper bound maps for all the identifiers that were 1081 // computed as affine expressions of the rest as the "detected expr" and 1082 // "detected expr + 1" respectively; set the undetected ones to null. 1083 Optional<FlatAffineValueConstraints> tmpClone; 1084 for (unsigned pos = 0; pos < num; pos++) { 1085 unsigned numMapDims = getNumDimIds() - num; 1086 unsigned numMapSymbols = getNumSymbolIds(); 1087 AffineExpr expr = memo[pos + offset]; 1088 if (expr) 1089 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); 1090 1091 AffineMap &lbMap = (*lbMaps)[pos]; 1092 AffineMap &ubMap = (*ubMaps)[pos]; 1093 1094 if (expr) { 1095 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); 1096 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment); 1097 } else { 1098 // TODO: Whenever there are local identifiers in the dependence 1099 // constraints, we'll conservatively over-approximate, since we don't 1100 // always explicitly compute them above (in the while loop). 1101 if (getNumLocalIds() == 0) { 1102 // Work on a copy so that we don't update this constraint system. 1103 if (!tmpClone) { 1104 tmpClone.emplace(FlatAffineValueConstraints(*this)); 1105 // Removing redundant inequalities is necessary so that we don't get 1106 // redundant loop bounds. 
1107 tmpClone->removeRedundantInequalities(); 1108 } 1109 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( 1110 pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context); 1111 } 1112 1113 // If the above fails, we'll just use the constant lower bound and the 1114 // constant upper bound (if they exist) as the slice bounds. 1115 // TODO: being conservative for the moment in cases that 1116 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is 1117 // fixed (b/126426796). 1118 if (!lbMap || lbMap.getNumResults() > 1) { 1119 LLVM_DEBUG(llvm::dbgs() 1120 << "WARNING: Potentially over-approximating slice lb\n"); 1121 auto lbConst = getConstantBound(BoundType::LB, pos + offset); 1122 if (lbConst.hasValue()) { 1123 lbMap = AffineMap::get( 1124 numMapDims, numMapSymbols, 1125 getAffineConstantExpr(lbConst.getValue(), context)); 1126 } 1127 } 1128 if (!ubMap || ubMap.getNumResults() > 1) { 1129 LLVM_DEBUG(llvm::dbgs() 1130 << "WARNING: Potentially over-approximating slice ub\n"); 1131 auto ubConst = getConstantBound(BoundType::UB, pos + offset); 1132 if (ubConst.hasValue()) { 1133 ubMap = 1134 AffineMap::get(numMapDims, numMapSymbols, 1135 getAffineConstantExpr( 1136 ubConst.getValue() + ubAdjustment, context)); 1137 } 1138 } 1139 } 1140 LLVM_DEBUG(llvm::dbgs() 1141 << "lb map for pos = " << Twine(pos + offset) << ", expr: "); 1142 LLVM_DEBUG(lbMap.dump();); 1143 LLVM_DEBUG(llvm::dbgs() 1144 << "ub map for pos = " << Twine(pos + offset) << ", expr: "); 1145 LLVM_DEBUG(ubMap.dump();); 1146 } 1147 } 1148 1149 LogicalResult FlatAffineValueConstraints::flattenAlignedMapAndMergeLocals( 1150 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { 1151 FlatAffineValueConstraints localCst; 1152 if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) { 1153 LLVM_DEBUG(llvm::dbgs() 1154 << "composition unimplemented for semi-affine maps\n"); 1155 return failure(); 1156 } 1157 1158 // Add localCst information. 
1159 if (localCst.getNumLocalIds() > 0) { 1160 unsigned numLocalIds = getNumLocalIds(); 1161 // Insert local dims of localCst at the beginning. 1162 insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds()); 1163 // Insert local dims of `this` at the end of localCst. 1164 localCst.appendLocalId(/*num=*/numLocalIds); 1165 // Dimensions of localCst and this constraint set match. Append localCst to 1166 // this constraint set. 1167 append(localCst); 1168 } 1169 1170 return success(); 1171 } 1172 1173 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, 1174 AffineMap boundMap, 1175 bool isClosedBound) { 1176 assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch"); 1177 assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); 1178 assert(pos < getNumDimAndSymbolIds() && "invalid position"); 1179 assert((type != BoundType::EQ || isClosedBound) && 1180 "EQ bound must be closed."); 1181 1182 // Equality follows the logic of lower bound except that we add an equality 1183 // instead of an inequality. 1184 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && 1185 "single result expected"); 1186 bool lower = type == BoundType::LB || type == BoundType::EQ; 1187 1188 std::vector<SmallVector<int64_t, 8>> flatExprs; 1189 if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs))) 1190 return failure(); 1191 assert(flatExprs.size() == boundMap.getNumResults()); 1192 1193 // Add one (in)equality for each result. 1194 for (const auto &flatExpr : flatExprs) { 1195 SmallVector<int64_t> ineq(getNumCols(), 0); 1196 // Dims and symbols. 1197 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { 1198 ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; 1199 } 1200 // Invalid bound: pos appears in `boundMap`. 1201 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or 1202 // its callers to prevent invalid bounds from being added. 
1203 if (ineq[pos] != 0) 1204 continue; 1205 ineq[pos] = lower ? 1 : -1; 1206 // Local columns of `ineq` are at the beginning. 1207 unsigned j = getNumDimIds() + getNumSymbolIds(); 1208 unsigned end = flatExpr.size() - 1; 1209 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { 1210 ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; 1211 } 1212 // Make the bound closed in if flatExpr is open. The inequality is always 1213 // created in the upper bound form, so the adjustment is -1. 1214 int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1; 1215 // Constant term. 1216 ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1] 1217 : flatExpr[flatExpr.size() - 1]) + 1218 boundAdjustment; 1219 type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq); 1220 } 1221 1222 return success(); 1223 } 1224 1225 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, 1226 AffineMap boundMap) { 1227 return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB); 1228 } 1229 1230 AffineMap 1231 FlatAffineValueConstraints::computeAlignedMap(AffineMap map, 1232 ValueRange operands) const { 1233 assert(map.getNumInputs() == operands.size() && "number of inputs mismatch"); 1234 1235 SmallVector<Value> dims, syms; 1236 #ifndef NDEBUG 1237 SmallVector<Value> newSyms; 1238 SmallVector<Value> *newSymsPtr = &newSyms; 1239 #else 1240 SmallVector<Value> *newSymsPtr = nullptr; 1241 #endif // NDEBUG 1242 1243 dims.reserve(getNumDimIds()); 1244 syms.reserve(getNumSymbolIds()); 1245 for (unsigned i = getIdKindOffset(IdKind::SetDim), 1246 e = getIdKindEnd(IdKind::SetDim); 1247 i < e; ++i) 1248 dims.push_back(values[i] ? *values[i] : Value()); 1249 for (unsigned i = getIdKindOffset(IdKind::Symbol), 1250 e = getIdKindEnd(IdKind::Symbol); 1251 i < e; ++i) 1252 syms.push_back(values[i] ? 
*values[i] : Value()); 1253 1254 AffineMap alignedMap = 1255 alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr); 1256 // All symbols are already part of this FlatAffineConstraints. 1257 assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols"); 1258 assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && 1259 "unexpected new/missing symbols"); 1260 return alignedMap; 1261 } 1262 1263 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, 1264 AffineMap boundMap, 1265 ValueRange boundOperands) { 1266 // Fully compose map and operands; canonicalize and simplify so that we 1267 // transitively get to terminal symbols or loop IVs. 1268 auto map = boundMap; 1269 SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end()); 1270 fullyComposeAffineMapAndOperands(&map, &operands); 1271 map = simplifyAffineMap(map); 1272 canonicalizeMapAndOperands(&map, &operands); 1273 for (auto operand : operands) 1274 addInductionVarOrTerminalSymbol(operand); 1275 return addBound(type, pos, computeAlignedMap(map, operands)); 1276 } 1277 1278 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper 1279 // bounds in 'ubMaps' to each value in `values' that appears in the constraint 1280 // system. Note that both lower/upper bounds share the same operand list 1281 // 'operands'. 1282 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and 1283 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'. 1284 // Note that both lower/upper bounds use operands from 'operands'. 1285 // Returns failure for unimplemented cases such as semi-affine expressions or 1286 // expressions with mod/floordiv. 
LogicalResult FlatAffineValueConstraints::addSliceBounds(
    ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
    ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
  assert(values.size() == lbMaps.size());
  assert(lbMaps.size() == ubMaps.size());

  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
    unsigned pos;
    // Values that are not identifiers of this system are ignored.
    if (!findId(values[i], &pos))
      continue;

    AffineMap lbMap = lbMaps[i];
    AffineMap ubMap = ubMaps[i];
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
    assert(!ubMap || ubMap.getNumInputs() == operands.size());

    // Check if this slice is just an equality along this dimension.
    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
        ubMap.getNumResults() == 1 &&
        lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
      if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
        return failure();
      continue;
    }

    // If lower or upper bound maps are null or provide no results, it implies
    // that the source loop was not at all sliced, and the entire loop will be a
    // part of the slice.
    if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
        ubMap.getNumResults() != 0) {
      if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
        return failure();
      if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
        return failure();
    } else {
      auto loop = getForInductionVarOwner(values[i]);
      if (failed(this->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

/// Looks up the dim/symbol identifier whose attached Value is `val`. On
/// success, stores its position in `*pos` and returns true; `*pos` is left
/// untouched otherwise. If several identifiers carry `val`, the first one is
/// reported.
bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const {
  unsigned i = 0;
  for (const auto &mayBeId : values) {
    if (mayBeId.hasValue() && mayBeId.getValue() == val) {
      *pos = i;
      return true;
    }
    i++;
  }
  return false;
}

/// Returns true if some dim/symbol identifier has `val` attached.
bool FlatAffineValueConstraints::containsId(Value val) const {
  return llvm::any_of(values, [&](const Optional<Value> &mayBeId) {
    return mayBeId.hasValue() && mayBeId.getValue() == val;
  });
}

/// Swaps identifiers `posA` and `posB`, keeping attached Values in sync.
/// `values` only has entries for dims and symbols. When exactly one side is a
/// local, the non-local position ends up holding the former local identifier,
/// so its Value is cleared.
void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) {
  IntegerPolyhedron::swapId(posA, posB);

  if (getIdKindAt(posA) == IdKind::Local && getIdKindAt(posB) == IdKind::Local)
    return;

  // Treat value of a local identifer as None.
  if (getIdKindAt(posA) == IdKind::Local)
    values[posB] = None;
  else if (getIdKindAt(posB) == IdKind::Local)
    values[posA] = None;
  else
    std::swap(values[posA], values[posB]);
}

/// Adds a constant bound of kind `type` on the identifier that has `val`
/// attached. `val` must be an identifier of this system.
void FlatAffineValueConstraints::addBound(BoundType type, Value val,
                                          int64_t value) {
  unsigned pos;
  if (!findId(val, &pos))
    // This is a pre-condition for this method.
    // NOTE(review): with NDEBUG the assert compiles away, and `pos` would then
    // be read uninitialized below if `val` is missing.
    assert(0 && "id not found");
  addBound(type, pos, value);
}

/// Prints the space of the underlying polyhedron, followed by one token per
/// identifier: "Value"/"None" for each dim and symbol (depending on whether a
/// Value is attached), and "Local" for each local identifier.
void FlatAffineValueConstraints::printSpace(raw_ostream &os) const {
  IntegerPolyhedron::printSpace(os);
  os << "(";
  for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; i++) {
    if (hasValue(i))
      os << "Value ";
    else
      os << "None ";
  }
  for (unsigned i = getIdKindOffset(IdKind::Local),
                e = getIdKindEnd(IdKind::Local);
       i < e; ++i)
    os << "Local ";
  os << " const)\n";
}

/// Replaces this system with a copy of `other`.
/// - If `other` is a FlatAffineValueConstraints, its attached Values are
///   copied as well.
/// - Otherwise only the IntegerRelation part is copied, and every dim/symbol
///   Value is reset to None.
void FlatAffineValueConstraints::clearAndCopyFrom(
    const IntegerRelation &other) {

  if (auto *otherValueSet =
          dyn_cast<const FlatAffineValueConstraints>(&other)) {
    *this = *otherValueSet;
  } else {
    *static_cast<IntegerRelation *>(this) = other;
    values.clear();
    values.resize(getNumDimAndSymbolIds(), None);
  }
}

/// Eliminates identifier `pos` with the base-class Fourier-Motzkin
/// implementation. The Values of the remaining dims/symbols are saved
/// beforehand and restored afterwards.
void FlatAffineValueConstraints::fourierMotzkinEliminate(
    unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
  SmallVector<Optional<Value>, 8> newVals = values;
  // Locals have no entry in `values`, so only a dim/symbol needs erasing.
  if (getIdKindAt(pos) != IdKind::Local)
    newVals.erase(newVals.begin() + pos);
  // Note: Base implementation discards all associated Values.
  IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
                                             isResultIntegerExact);
  values = newVals;
  assert(values.size() == getNumDimAndSymbolIds());
}

/// Projects out (eliminates) the identifier that has `val` attached. `val`
/// must be an identifier of this system.
void FlatAffineValueConstraints::projectOut(Value val) {
  unsigned pos;
  bool ret = findId(val, &pos);
  assert(ret);
  (void)ret;
  fourierMotzkinEliminate(pos);
}

/// Delegates to IntegerPolyhedron::unionBoundingBox with `otherCst`.
///
/// Requirements: both systems have the same number of dims, carrying the same
/// Values, and neither has local identifiers.
///
/// If the identifiers are not aligned, a copy of `otherCst` is aligned first.
/// mergeAndAlignIds is handed `this` as well, so this system's identifiers
/// may also be rearranged.
LogicalResult FlatAffineValueConstraints::unionBoundingBox(
    const FlatAffineValueConstraints &otherCst) {
  assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch");
  assert(otherCst.getMaybeValues()
             .slice(0, getNumDimIds())
             .equals(getMaybeValues().slice(0, getNumDimIds())) &&
         "dim values mismatch");
  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
  assert(getNumLocalIds() == 0 && "local ids not supported yet here");

  // Align `other` to this.
  if (!areIdsAligned(*this, otherCst)) {
    FlatAffineValueConstraints otherCopy(otherCst);
    mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy);
    return IntegerPolyhedron::unionBoundingBox(otherCopy);
  }

  return IntegerPolyhedron::unionBoundingBox(otherCst);
}

/// Compute an explicit representation for local vars. For all systems coming
/// from MLIR integer sets, maps, or expressions where local vars were
/// introduced to model floordivs and mods, this always succeeds.
///
/// `memo` must have one entry per identifier of `cst`, and local entries are
/// expected to be null on entry. On return:
/// - dim/symbol entries hold the corresponding dim/symbol expressions;
/// - local entries hold the detected floordiv expressions, or stay null if
///   none was found.
/// Returns failure if any local identifier is left without an expression.
static LogicalResult computeLocalVars(const FlatAffineValueConstraints &cst,
                                      SmallVectorImpl<AffineExpr> &memo,
                                      MLIRContext *context) {
  unsigned numDims = cst.getNumDimIds();
  unsigned numSyms = cst.getNumSymbolIds();

  // Initialize dimensional and symbolic identifiers.
  for (unsigned i = 0; i < numDims; i++)
    memo[i] = getAffineDimExpr(i, context);
  for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
    memo[i] = getAffineSymbolExpr(i - numDims, context);

  bool changed;
  do {
    // Each time `changed` is true at the end of this iteration, one or more
    // local vars would have been detected as floordivs and set in memo; so the
    // number of null entries in memo[...] strictly reduces; so this converges.
    changed = false;
    for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
      if (!memo[numDims + numSyms + i] &&
          detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
        changed = true;
  } while (changed);

  ArrayRef<AffineExpr> localExprs =
      ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
  return success(
      llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
}

/// Stores in `vmap` the bound on dim `pos` described by inequality `ineqPos`.
/// It is a lower bound if the inequality's coefficient at `pos` is positive,
/// and an exclusive upper bound otherwise. The map's inputs are the other
/// `numDims - 1` dims and all symbols, bound to their attached Values.
/// NOTE(review): the coefficient at `pos` is dropped but never divided out, so
/// this presumably assumes a unit coefficient - confirm with callers.
void FlatAffineValueConstraints::getIneqAsAffineValueMap(
    unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
    MLIRContext *context) const {
  unsigned numDims = getNumDimIds();
  unsigned numSyms = getNumSymbolIds();

  assert(pos < numDims && "invalid position");
  assert(ineqPos < getNumInequalities() && "invalid inequality position");

  // Get expressions for local vars.
  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
  if (failed(computeLocalVars(*this, memo, context)))
    assert(false &&
           "one or more local exprs do not have an explicit representation");
  auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());

  // Compute the AffineExpr lower/upper bound for this inequality.
  ArrayRef<int64_t> inequality = getInequality(ineqPos);
  SmallVector<int64_t, 8> bound;
  bound.reserve(getNumCols() - 1);
  // Everything other than the coefficient at `pos`.
  bound.append(inequality.begin(), inequality.begin() + pos);
  bound.append(inequality.begin() + pos + 1, inequality.end());

  if (inequality[pos] > 0)
    // Lower bound.
    std::transform(bound.begin(), bound.end(), bound.begin(),
                   std::negate<int64_t>());
  else
    // Upper bound (which is exclusive).
    bound.back() += 1;

  // Convert to AffineExpr (tree) form.
  auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
                                             localExprs, context);

  // Get the values to bind to this affine expr (all dims and symbols).
  SmallVector<Value, 4> operands;
  getValues(0, pos, &operands);
  SmallVector<Value, 4> trailingOperands;
  getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
  operands.append(trailingOperands.begin(), trailingOperands.end());
  vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
}

/// Returns this system as an IntegerSet over its dims and symbols, with the
/// equalities first and then the inequalities. Local identifiers are replaced
/// by their detected explicit expressions.
/// - An unconstrained system yields the universal set `0 == 0`.
/// - Returns a null IntegerSet if a local identifier with a non-zero column
///   has no explicit representation.
IntegerSet
FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const {
  if (getNumConstraints() == 0)
    // Return universal set (always true): 0 == 0.
    return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
                           getAffineConstantExpr(/*constant=*/0, context),
                           /*eqFlags=*/true);

  // Construct local references.
  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());

  if (failed(computeLocalVars(*this, memo, context))) {
    // Check if the local variables without an explicit representation have
    // zero coefficients everywhere.
    SmallVector<unsigned> noLocalRepVars;
    unsigned numDimsSymbols = getNumDimAndSymbolIds();
    for (unsigned i = numDimsSymbols, e = getNumIds(); i < e; ++i) {
      if (!memo[i] && !isColZero(/*pos=*/i))
        noLocalRepVars.push_back(i - numDimsSymbols);
    }
    if (!noLocalRepVars.empty()) {
      LLVM_DEBUG({
        llvm::dbgs() << "local variables at position(s) ";
        llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
        llvm::dbgs() << " do not have an explicit representation in:\n";
        this->dump();
      });
      return IntegerSet();
    }
  }

  ArrayRef<AffineExpr> localExprs =
      ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());

  // Construct the IntegerSet from the equalities/inequalities.
  unsigned numDims = getNumDimIds();
  unsigned numSyms = getNumSymbolIds();

  SmallVector<bool, 16> eqFlags(getNumConstraints());
  std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
  std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);

  SmallVector<AffineExpr, 8> exprs;
  exprs.reserve(getNumConstraints());

  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
    exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
                                              localExprs, context));
  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
    exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
                                              numSyms, localExprs, context));
  return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
}

/// Returns `map` re-expressed over `dims.size()` dims and the symbols in
/// `syms`, plus any newly created ones.
/// - Each operand is replaced by its position in `dims`; if it is not found
///   there, by its position in `syms`.
/// - Operands found in neither become new trailing symbols. If `newSyms` is
///   provided, it is reset to `syms` and then gets those operands appended.
/// NOTE(review): an unmatched operand that occurs more than once gets a
/// distinct new symbol per occurrence, since only `syms` is searched.
AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
                                         ValueRange dims, ValueRange syms,
                                         SmallVector<Value> *newSyms) {
  assert(operands.size() == map.getNumInputs() &&
         "expected same number of operands and map inputs");
  MLIRContext *ctx = map.getContext();
  Builder builder(ctx);
  SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
  unsigned numSymbols = syms.size();
  SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
  if (newSyms) {
    newSyms->clear();
    newSyms->append(syms.begin(), syms.end());
  }

  for (const auto &operand : llvm::enumerate(operands)) {
    // Compute replacement dim/sym of operand.
    AffineExpr replacement;
    auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
    auto symIt = std::find(syms.begin(), syms.end(), operand.value());
    if (dimIt != dims.end()) {
      replacement =
          builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
    } else if (symIt != syms.end()) {
      replacement =
          builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
    } else {
      // This operand is neither a dimension nor a symbol. Add it as a new
      // symbol.
      replacement = builder.getAffineSymbolExpr(numSymbols++);
      if (newSyms)
        newSyms->push_back(operand.value());
    }
    // Add to corresponding replacements vector.
    if (operand.index() < map.getNumDims()) {
      dimReplacements[operand.index()] = replacement;
    } else {
      symReplacements[operand.index() - map.getNumDims()] = replacement;
    }
  }

  return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                   dims.size(), numSymbols);
}

/// Returns the domain of this relation as a set; range dims become locals.
FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
  FlatAffineValueConstraints domain = *this;
  // Convert all range variables to local variables.
  domain.convertToLocal(IdKind::SetDim, getNumDomainDims(),
                        getNumDomainDims() + getNumRangeDims());
  return domain;
}

/// Returns the range of this relation as a set; domain dims become locals.
FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
  FlatAffineValueConstraints range = *this;
  // Convert all domain variables to local variables.
  range.convertToLocal(IdKind::SetDim, 0, getNumDomainDims());
  return range;
}

/// Composes `other` into this relation. On return `this` is
/// [otherDomain] -> [thisRange], with the intermediate dims (otherRange ==
/// thisDomain) turned into locals. The domain of this relation must match the
/// range of `other` in count and attached Values.
void FlatAffineRelation::compose(const FlatAffineRelation &other) {
  assert(getNumDomainDims() == other.getNumRangeDims() &&
         "Domain of this and range of other do not match");
  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
                    other.values.begin() + other.getNumDomainDims()) &&
         "Domain of this and range of other do not match");

  FlatAffineRelation rel = other;

  // Convert `rel` from
  //    [otherDomain] -> [otherRange]
  // to
  //    [otherDomain] -> [otherRange thisRange]
  // and `this` from
  //    [thisDomain] -> [thisRange]
  // to
  //    [otherDomain thisDomain] -> [thisRange].
  unsigned removeDims = rel.getNumRangeDims();
  insertDomainId(0, rel.getNumDomainDims());
  rel.appendRangeId(getNumRangeDims());

  // Merge symbol and local identifiers.
  mergeSymbolIds(rel);
  mergeLocalIds(rel);

  // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
  // [otherDomain] -> [thisRange] by converting first otherRange range ids
  // to local ids.
  rel.convertToLocal(IdKind::SetDim, rel.getNumDomainDims(),
                     rel.getNumDomainDims() + removeDims);
  // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
  // [otherDomain] -> [thisRange] by converting last thisDomain domain ids
  // to local ids.
  convertToLocal(IdKind::SetDim, getNumDomainDims() - removeDims,
                 getNumDomainDims());

  auto thisMaybeValues = getMaybeValues(IdKind::SetDim);
  auto relMaybeValues = rel.getMaybeValues(IdKind::SetDim);

  // Add and match domain of `rel` to domain of `this`.
  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
    if (relMaybeValues[i].hasValue())
      setValue(i, relMaybeValues[i].getValue());
  // Add and match range of `this` to range of `rel`.
  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
    unsigned rangeIdx = rel.getNumDomainDims() + i;
    if (thisMaybeValues[rangeIdx].hasValue())
      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
  }

  // Append `this` to `rel` and simplify constraints.
  rel.append(*this);
  rel.removeRedundantLocalVars();

  *this = rel;
}

/// Swaps the domain and range of this relation in place: the old domain dims
/// are moved after the old range dims.
void FlatAffineRelation::inverse() {
  unsigned oldDomain = getNumDomainDims();
  unsigned oldRange = getNumRangeDims();
  // Add new range ids.
  appendRangeId(oldDomain);
  // Swap new ids with domain.
  for (unsigned i = 0; i < oldDomain; ++i)
    swapId(i, oldDomain + oldRange + i);
  // Remove the swapped domain.
  removeIdRange(0, oldDomain);
  // Set domain and range as inverse.
  numDomainDims = oldRange;
  numRangeDims = oldDomain;
}

/// Inserts `num` domain dims at domain position `pos`.
void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) {
  assert(pos <= getNumDomainDims() &&
         "Id cannot be inserted at invalid position");
  insertDimId(pos, num);
  numDomainDims += num;
}

/// Inserts `num` range dims at range position `pos` (i.e. after the domain).
void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) {
  assert(pos <= getNumRangeDims() &&
         "Id cannot be inserted at invalid position");
  insertDimId(getNumDomainDims() + pos, num);
  numRangeDims += num;
}

/// Appends `num` domain dims after the existing domain dims.
void FlatAffineRelation::appendDomainId(unsigned num) {
  insertDimId(getNumDomainDims(), num);
  numDomainDims += num;
}

/// Appends `num` range dims after all existing dims.
void FlatAffineRelation::appendRangeId(unsigned num) {
  insertDimId(getNumDimIds(), num);
  numRangeDims += num;
}

void FlatAffineRelation::removeIdRange(IdKind kind, unsigned idStart,
                                       unsigned idLimit) {
  assert(idLimit <= getNumIdKind(kind));
  if (idStart >= idLimit)
    return;

  FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit);

  // If kind is not SetDim, domain and range don't need to be updated.
1738 if (kind != IdKind::SetDim) 1739 return; 1740 1741 // Compute number of domain and range identifiers to remove. This is done by 1742 // intersecting the range of domain/range ids with range of ids to remove. 1743 unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims()); 1744 unsigned intersectDomainRHS = idStart; 1745 unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds()); 1746 unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims()); 1747 1748 if (intersectDomainLHS > intersectDomainRHS) 1749 numDomainDims -= intersectDomainLHS - intersectDomainRHS; 1750 if (intersectRangeLHS > intersectRangeRHS) 1751 numRangeDims -= intersectRangeLHS - intersectRangeRHS; 1752 } 1753 1754 LogicalResult mlir::getRelationFromMap(AffineMap &map, 1755 FlatAffineRelation &rel) { 1756 // Get flattened affine expressions. 1757 std::vector<SmallVector<int64_t, 8>> flatExprs; 1758 FlatAffineValueConstraints localVarCst; 1759 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) 1760 return failure(); 1761 1762 unsigned oldDimNum = localVarCst.getNumDimIds(); 1763 unsigned oldCols = localVarCst.getNumCols(); 1764 unsigned numRangeIds = map.getNumResults(); 1765 unsigned numDomainIds = map.getNumDims(); 1766 1767 // Add range as the new expressions. 1768 localVarCst.appendDimId(numRangeIds); 1769 1770 // Add equalities between source and range. 1771 SmallVector<int64_t, 8> eq(localVarCst.getNumCols()); 1772 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { 1773 // Zero fill. 1774 std::fill(eq.begin(), eq.end(), 0); 1775 // Fill equality. 1776 for (unsigned j = 0, f = oldDimNum; j < f; ++j) 1777 eq[j] = flatExprs[i][j]; 1778 for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) 1779 eq[j + numRangeIds] = flatExprs[i][j]; 1780 // Set this dimension to -1 to equate lhs and rhs and add equality. 1781 eq[numDomainIds + i] = -1; 1782 localVarCst.addEquality(eq); 1783 } 1784 1785 // Create relation and return success. 
1786 rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst); 1787 return success(); 1788 } 1789 1790 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map, 1791 FlatAffineRelation &rel) { 1792 1793 AffineMap affineMap = map.getAffineMap(); 1794 if (failed(getRelationFromMap(affineMap, rel))) 1795 return failure(); 1796 1797 // Set symbol values for domain dimensions and symbols. 1798 for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) 1799 rel.setValue(i, map.getOperand(i)); 1800 for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e; 1801 ++i) 1802 rel.setValue(i, map.getOperand(i - rel.getNumRangeDims())); 1803 1804 return success(); 1805 } 1806