1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "AffineMapDetail.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Support/LogicalResult.h" 14 #include "mlir/Support/MathExtras.h" 15 #include "llvm/ADT/SmallBitVector.h" 16 #include "llvm/ADT/SmallSet.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 // AffineExprConstantFolder evaluates an affine expression using constant 25 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute 26 // representing the constant value of the affine expression evaluated on 27 // constant 'operandConsts', or nullptr if it can't be folded. 28 class AffineExprConstantFolder { 29 public: 30 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) 31 : numDims(numDims), operandConsts(operandConsts) {} 32 33 /// Attempt to constant fold the specified affine expr, or return null on 34 /// failure. 35 IntegerAttr constantFold(AffineExpr expr) { 36 if (auto result = constantFoldImpl(expr)) 37 return IntegerAttr::get(IndexType::get(expr.getContext()), *result); 38 return nullptr; 39 } 40 41 private: 42 Optional<int64_t> constantFoldImpl(AffineExpr expr) { 43 switch (expr.getKind()) { 44 case AffineExprKind::Add: 45 return constantFoldBinExpr( 46 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); 47 case AffineExprKind::Mul: 48 return constantFoldBinExpr( 49 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); 50 case AffineExprKind::Mod: 51 return constantFoldBinExpr( 52 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); 53 case AffineExprKind::FloorDiv: 54 return constantFoldBinExpr( 55 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); 56 case AffineExprKind::CeilDiv: 57 return constantFoldBinExpr( 58 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); 59 case AffineExprKind::Constant: 60 return expr.cast<AffineConstantExpr>().getValue(); 61 case AffineExprKind::DimId: 62 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()] 63 .dyn_cast_or_null<IntegerAttr>()) 64 return attr.getInt(); 65 return llvm::None; 66 case AffineExprKind::SymbolId: 67 if (auto attr = operandConsts[numDims + 68 expr.cast<AffineSymbolExpr>().getPosition()] 69 .dyn_cast_or_null<IntegerAttr>()) 70 return attr.getInt(); 71 return llvm::None; 72 } 73 llvm_unreachable("Unknown AffineExpr"); 74 } 75 76 // TODO: Change these to operate on APInts too. 77 Optional<int64_t> constantFoldBinExpr(AffineExpr expr, 78 int64_t (*op)(int64_t, int64_t)) { 79 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 80 if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) 81 if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) 82 return op(*lhs, *rhs); 83 return llvm::None; 84 } 85 86 // The number of dimension operands in AffineMap containing this expression. 87 unsigned numDims; 88 // The constant valued operands used to evaluate this AffineExpr. 89 ArrayRef<Attribute> operandConsts; 90 }; 91 92 } // end anonymous namespace 93 94 /// Returns a single constant result affine map. 95 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { 96 return get(/*dimCount=*/0, /*symbolCount=*/0, 97 {getAffineConstantExpr(val, context)}); 98 } 99 100 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most 101 /// minor dimensions. 102 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, 103 MLIRContext *context) { 104 assert(dims >= results && "Dimension mismatch"); 105 auto id = AffineMap::getMultiDimIdentityMap(dims, context); 106 return AffineMap::get(dims, 0, id.getResults().take_back(results), context); 107 } 108 109 bool AffineMap::isMinorIdentity() const { 110 return getNumDims() >= getNumResults() && 111 *this == 112 getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); 113 } 114 115 /// Returns true if this affine map is a minor identity up to broadcasted 116 /// dimensions which are indicated by value 0 in the result. 117 bool AffineMap::isMinorIdentityWithBroadcasting( 118 SmallVectorImpl<unsigned> *broadcastedDims) const { 119 if (broadcastedDims) 120 broadcastedDims->clear(); 121 if (getNumDims() < getNumResults()) 122 return false; 123 unsigned suffixStart = getNumDims() - getNumResults(); 124 for (auto idxAndExpr : llvm::enumerate(getResults())) { 125 unsigned resIdx = idxAndExpr.index(); 126 AffineExpr expr = idxAndExpr.value(); 127 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 128 // Each result may be either a constant 0 (broadcasted dimension). 129 if (constExpr.getValue() != 0) 130 return false; 131 if (broadcastedDims) 132 broadcastedDims->push_back(resIdx); 133 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 134 // Or it may be the input dimension corresponding to this result position. 135 if (dimExpr.getPosition() != suffixStart + resIdx) 136 return false; 137 } else { 138 return false; 139 } 140 } 141 return true; 142 } 143 144 /// Return true if this affine map can be converted to a minor identity with 145 /// broadcast by doing a permute. Return a permutation (there may be 146 /// several) to apply to get to a minor identity with broadcasts. 147 /// Ex: 148 /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with 149 /// perm = [1, 0] and broadcast d2 150 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by 151 /// permutation + broadcast 152 /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) 153 /// with perm = [1, 0, 2] and broadcast d2 154 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra 155 /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with 156 /// perm = [3, 0, 1, 2] 157 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( 158 SmallVectorImpl<unsigned> &permutedDims) const { 159 unsigned projectionStart = 160 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; 161 permutedDims.clear(); 162 SmallVector<unsigned> broadcastDims; 163 permutedDims.resize(getNumResults(), 0); 164 // If there are more results than input dimensions we want the new map to 165 // start with broadcast dimensions in order to be a minor identity with 166 // broadcasting. 167 unsigned leadingBroadcast = 168 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0; 169 llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()), 170 false); 171 for (auto idxAndExpr : llvm::enumerate(getResults())) { 172 unsigned resIdx = idxAndExpr.index(); 173 AffineExpr expr = idxAndExpr.value(); 174 // Each result may be either a constant 0 (broadcast dimension) or a 175 // dimension. 176 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 177 if (constExpr.getValue() != 0) 178 return false; 179 broadcastDims.push_back(resIdx); 180 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 181 if (dimExpr.getPosition() < projectionStart) 182 return false; 183 unsigned newPosition = 184 dimExpr.getPosition() - projectionStart + leadingBroadcast; 185 permutedDims[resIdx] = newPosition; 186 dimFound[newPosition] = true; 187 } else { 188 return false; 189 } 190 } 191 // Find a permuation for the broadcast dimension. Since they are broadcasted 192 // any valid permutation is acceptable. We just permute the dim into a slot 193 // without an existing dimension. 194 unsigned pos = 0; 195 for (auto dim : broadcastDims) { 196 while (pos < dimFound.size() && dimFound[pos]) { 197 pos++; 198 } 199 permutedDims[dim] = pos++; 200 } 201 return true; 202 } 203 204 /// Returns an AffineMap representing a permutation. 205 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, 206 MLIRContext *context) { 207 assert(!permutation.empty() && 208 "Cannot create permutation map from empty permutation vector"); 209 SmallVector<AffineExpr, 4> affExprs; 210 for (auto index : permutation) 211 affExprs.push_back(getAffineDimExpr(index, context)); 212 auto m = std::max_element(permutation.begin(), permutation.end()); 213 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); 214 assert(permutationMap.isPermutation() && "Invalid permutation vector"); 215 return permutationMap; 216 } 217 218 template <typename AffineExprContainer> 219 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList, 220 int64_t &maxDim, int64_t &maxSym) { 221 for (const auto &exprs : exprsList) { 222 for (auto expr : exprs) { 223 expr.walk([&maxDim, &maxSym](AffineExpr e) { 224 if (auto d = e.dyn_cast<AffineDimExpr>()) 225 maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition())); 226 if (auto s = e.dyn_cast<AffineSymbolExpr>()) 227 maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition())); 228 }); 229 } 230 } 231 } 232 233 template <typename AffineExprContainer> 234 static SmallVector<AffineMap, 4> 235 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) { 236 assert(!exprsList.empty()); 237 assert(!exprsList[0].empty()); 238 auto context = exprsList[0][0].getContext(); 239 int64_t maxDim = -1, maxSym = -1; 240 getMaxDimAndSymbol(exprsList, maxDim, maxSym); 241 SmallVector<AffineMap, 4> maps; 242 maps.reserve(exprsList.size()); 243 for (const auto &exprs : exprsList) 244 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, 245 /*symbolCount=*/maxSym + 1, exprs, context)); 246 return maps; 247 } 248 249 SmallVector<AffineMap, 4> 250 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) { 251 return ::inferFromExprList(exprsList); 252 } 253 254 SmallVector<AffineMap, 4> 255 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) { 256 return ::inferFromExprList(exprsList); 257 } 258 259 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, 260 MLIRContext *context) { 261 SmallVector<AffineExpr, 4> dimExprs; 262 dimExprs.reserve(numDims); 263 for (unsigned i = 0; i < numDims; ++i) 264 dimExprs.push_back(mlir::getAffineDimExpr(i, context)); 265 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); 266 } 267 268 MLIRContext *AffineMap::getContext() const { return map->context; } 269 270 bool AffineMap::isIdentity() const { 271 if (getNumDims() != getNumResults()) 272 return false; 273 ArrayRef<AffineExpr> results = getResults(); 274 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { 275 auto expr = results[i].dyn_cast<AffineDimExpr>(); 276 if (!expr || expr.getPosition() != i) 277 return false; 278 } 279 return true; 280 } 281 282 bool AffineMap::isEmpty() const { 283 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; 284 } 285 286 bool AffineMap::isSingleConstant() const { 287 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>(); 288 } 289 290 bool AffineMap::isConstant() const { 291 return llvm::all_of(getResults(), [](AffineExpr expr) { 292 return expr.isa<AffineConstantExpr>(); 293 }); 294 } 295 296 int64_t AffineMap::getSingleConstantResult() const { 297 assert(isSingleConstant() && "map must have a single constant result"); 298 return getResult(0).cast<AffineConstantExpr>().getValue(); 299 } 300 301 SmallVector<int64_t> AffineMap::getConstantResults() const { 302 assert(isConstant() && "map must have only constant results"); 303 SmallVector<int64_t> result; 304 for (auto expr : getResults()) 305 result.emplace_back(expr.cast<AffineConstantExpr>().getValue()); 306 return result; 307 } 308 309 unsigned AffineMap::getNumDims() const { 310 assert(map && "uninitialized map storage"); 311 return map->numDims; 312 } 313 unsigned AffineMap::getNumSymbols() const { 314 assert(map && "uninitialized map storage"); 315 return map->numSymbols; 316 } 317 unsigned AffineMap::getNumResults() const { 318 assert(map && "uninitialized map storage"); 319 return map->results.size(); 320 } 321 unsigned AffineMap::getNumInputs() const { 322 assert(map && "uninitialized map storage"); 323 return map->numDims + map->numSymbols; 324 } 325 326 ArrayRef<AffineExpr> AffineMap::getResults() const { 327 assert(map && "uninitialized map storage"); 328 return map->results; 329 } 330 AffineExpr AffineMap::getResult(unsigned idx) const { 331 assert(map && "uninitialized map storage"); 332 return map->results[idx]; 333 } 334 335 unsigned AffineMap::getDimPosition(unsigned idx) const { 336 return getResult(idx).cast<AffineDimExpr>().getPosition(); 337 } 338 339 unsigned AffineMap::getPermutedPosition(unsigned input) const { 340 assert(isPermutation() && "invalid permutation request"); 341 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) 342 if (getDimPosition(i) == input) 343 return i; 344 llvm_unreachable("incorrect permutation request"); 345 } 346 347 /// Folds the results of the application of an affine map on the provided 348 /// operands to a constant if possible. Returns false if the folding happens, 349 /// true otherwise. 350 LogicalResult 351 AffineMap::constantFold(ArrayRef<Attribute> operandConstants, 352 SmallVectorImpl<Attribute> &results) const { 353 // Attempt partial folding. 354 SmallVector<int64_t, 2> integers; 355 partialConstantFold(operandConstants, &integers); 356 357 // If all expressions folded to a constant, populate results with attributes 358 // containing those constants. 359 if (integers.empty()) 360 return failure(); 361 362 auto range = llvm::map_range(integers, [this](int64_t i) { 363 return IntegerAttr::get(IndexType::get(getContext()), i); 364 }); 365 results.append(range.begin(), range.end()); 366 return success(); 367 } 368 369 AffineMap 370 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, 371 SmallVectorImpl<int64_t> *results) const { 372 assert(getNumInputs() == operandConstants.size()); 373 374 // Fold each of the result expressions. 375 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); 376 SmallVector<AffineExpr, 4> exprs; 377 exprs.reserve(getNumResults()); 378 379 for (auto expr : getResults()) { 380 auto folded = exprFolder.constantFold(expr); 381 // If did not fold to a constant, keep the original expression, and clear 382 // the integer results vector. 383 if (folded) { 384 exprs.push_back( 385 getAffineConstantExpr(folded.getInt(), folded.getContext())); 386 if (results) 387 results->push_back(folded.getInt()); 388 } else { 389 exprs.push_back(expr); 390 if (results) { 391 results->clear(); 392 results = nullptr; 393 } 394 } 395 } 396 397 return get(getNumDims(), getNumSymbols(), exprs, getContext()); 398 } 399 400 /// Walk all of the AffineExpr's in this mapping. Each node in an expression 401 /// tree is visited in postorder. 402 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const { 403 for (auto expr : getResults()) 404 expr.walk(callback); 405 } 406 407 /// This method substitutes any uses of dimensions and symbols (e.g. 408 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified 409 /// expression mapping. Because this can be used to eliminate dims and 410 /// symbols, the client needs to specify the number of dims and symbols in 411 /// the result. The returned map always has the same number of results. 412 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 413 ArrayRef<AffineExpr> symReplacements, 414 unsigned numResultDims, 415 unsigned numResultSyms) const { 416 SmallVector<AffineExpr, 8> results; 417 results.reserve(getNumResults()); 418 for (auto expr : getResults()) 419 results.push_back( 420 expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); 421 return get(numResultDims, numResultSyms, results, getContext()); 422 } 423 424 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to 425 /// each of the results and return a new AffineMap with the new results and 426 /// with the specified number of dims and symbols. 427 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement, 428 unsigned numResultDims, 429 unsigned numResultSyms) const { 430 SmallVector<AffineExpr, 4> newResults; 431 newResults.reserve(getNumResults()); 432 for (AffineExpr e : getResults()) 433 newResults.push_back(e.replace(expr, replacement)); 434 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 435 } 436 437 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the 438 /// results and return a new AffineMap with the new results and with the 439 /// specified number of dims and symbols. 440 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map, 441 unsigned numResultDims, 442 unsigned numResultSyms) const { 443 SmallVector<AffineExpr, 4> newResults; 444 newResults.reserve(getNumResults()); 445 for (AffineExpr e : getResults()) 446 newResults.push_back(e.replace(map)); 447 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 448 } 449 450 AffineMap 451 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 452 SmallVector<AffineExpr, 4> newResults; 453 newResults.reserve(getNumResults()); 454 for (AffineExpr e : getResults()) 455 newResults.push_back(e.replace(map)); 456 return AffineMap::inferFromExprList(newResults).front(); 457 } 458 459 AffineMap AffineMap::compose(AffineMap map) const { 460 assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); 461 // Prepare `map` by concatenating the symbols and rewriting its exprs. 462 unsigned numDims = map.getNumDims(); 463 unsigned numSymbolsThisMap = getNumSymbols(); 464 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); 465 SmallVector<AffineExpr, 8> newDims(numDims); 466 for (unsigned idx = 0; idx < numDims; ++idx) { 467 newDims[idx] = getAffineDimExpr(idx, getContext()); 468 } 469 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap); 470 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { 471 newSymbols[idx - numSymbolsThisMap] = 472 getAffineSymbolExpr(idx, getContext()); 473 } 474 auto newMap = 475 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); 476 SmallVector<AffineExpr, 8> exprs; 477 exprs.reserve(getResults().size()); 478 for (auto expr : getResults()) 479 exprs.push_back(expr.compose(newMap)); 480 return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); 481 } 482 483 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const { 484 assert(getNumSymbols() == 0 && "Expected symbol-less map"); 485 SmallVector<AffineExpr, 4> exprs; 486 exprs.reserve(values.size()); 487 MLIRContext *ctx = getContext(); 488 for (auto v : values) 489 exprs.push_back(getAffineConstantExpr(v, ctx)); 490 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); 491 SmallVector<int64_t, 4> res; 492 res.reserve(resMap.getNumResults()); 493 for (auto e : resMap.getResults()) 494 res.push_back(e.cast<AffineConstantExpr>().getValue()); 495 return res; 496 } 497 498 bool AffineMap::isProjectedPermutation() const { 499 if (getNumSymbols() > 0) 500 return false; 501 SmallVector<bool, 8> seen(getNumInputs(), false); 502 for (auto expr : getResults()) { 503 if (auto dim = expr.dyn_cast<AffineDimExpr>()) { 504 if (seen[dim.getPosition()]) 505 return false; 506 seen[dim.getPosition()] = true; 507 continue; 508 } 509 return false; 510 } 511 return true; 512 } 513 514 bool AffineMap::isPermutation() const { 515 if (getNumDims() != getNumResults()) 516 return false; 517 return isProjectedPermutation(); 518 } 519 520 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const { 521 SmallVector<AffineExpr, 4> exprs; 522 exprs.reserve(resultPos.size()); 523 for (auto idx : resultPos) 524 exprs.push_back(getResult(idx)); 525 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); 526 } 527 528 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { 529 return AffineMap::get(getNumDims(), getNumSymbols(), 530 getResults().slice(start, length), getContext()); 531 } 532 533 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { 534 if (numResults == 0) 535 return AffineMap(); 536 if (numResults > getNumResults()) 537 return *this; 538 return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults))); 539 } 540 541 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { 542 if (numResults == 0) 543 return AffineMap(); 544 if (numResults > getNumResults()) 545 return *this; 546 return getSubMap(llvm::to_vector<4>( 547 llvm::seq<unsigned>(getNumResults() - numResults, getNumResults()))); 548 } 549 550 AffineMap mlir::compressDims(AffineMap map, 551 const llvm::SmallDenseSet<unsigned> &unusedDims) { 552 unsigned numDims = 0; 553 SmallVector<AffineExpr> dimReplacements; 554 dimReplacements.reserve(map.getNumDims()); 555 MLIRContext *context = map.getContext(); 556 for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) { 557 if (unusedDims.contains(dim)) 558 dimReplacements.push_back(getAffineConstantExpr(0, context)); 559 else 560 dimReplacements.push_back(getAffineDimExpr(numDims++, context)); 561 } 562 SmallVector<AffineExpr> resultExprs; 563 resultExprs.reserve(map.getNumResults()); 564 for (auto e : map.getResults()) 565 resultExprs.push_back(e.replaceDims(dimReplacements)); 566 return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context); 567 } 568 569 AffineMap mlir::compressUnusedDims(AffineMap map) { 570 llvm::SmallDenseSet<unsigned> usedDims; 571 map.walkExprs([&](AffineExpr expr) { 572 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) 573 usedDims.insert(dimExpr.getPosition()); 574 }); 575 llvm::SmallDenseSet<unsigned> unusedDims; 576 for (unsigned d = 0, e = map.getNumDims(); d != e; ++d) 577 if (!usedDims.contains(d)) 578 unusedDims.insert(d); 579 return compressDims(map, unusedDims); 580 } 581 582 static SmallVector<AffineMap> 583 compressUnusedImpl(ArrayRef<AffineMap> maps, 584 llvm::function_ref<AffineMap(AffineMap)> compressionFun) { 585 if (maps.empty()) 586 return SmallVector<AffineMap>(); 587 SmallVector<AffineExpr> allExprs; 588 allExprs.reserve(maps.size() * maps.front().getNumResults()); 589 unsigned numDims = maps.front().getNumDims(), 590 numSymbols = maps.front().getNumSymbols(); 591 for (auto m : maps) { 592 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && 593 "expected maps with same num dims and symbols"); 594 llvm::append_range(allExprs, m.getResults()); 595 } 596 AffineMap unifiedMap = compressionFun( 597 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext())); 598 unsigned unifiedNumDims = unifiedMap.getNumDims(), 599 unifiedNumSymbols = unifiedMap.getNumSymbols(); 600 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); 601 SmallVector<AffineMap> res; 602 res.reserve(maps.size()); 603 for (auto m : maps) { 604 res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols, 605 unifiedResults.take_front(m.getNumResults()), 606 m.getContext())); 607 unifiedResults = unifiedResults.drop_front(m.getNumResults()); 608 } 609 return res; 610 } 611 612 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { 613 return compressUnusedImpl(maps, 614 [](AffineMap m) { return compressUnusedDims(m); }); 615 } 616 617 AffineMap 618 mlir::compressSymbols(AffineMap map, 619 const llvm::SmallDenseSet<unsigned> &unusedSymbols) { 620 unsigned numSymbols = 0; 621 SmallVector<AffineExpr> symReplacements; 622 symReplacements.reserve(map.getNumSymbols()); 623 MLIRContext *context = map.getContext(); 624 for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) { 625 if (unusedSymbols.contains(sym)) 626 symReplacements.push_back(getAffineConstantExpr(0, context)); 627 else 628 symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context)); 629 } 630 SmallVector<AffineExpr> resultExprs; 631 resultExprs.reserve(map.getNumResults()); 632 for (auto e : map.getResults()) 633 resultExprs.push_back(e.replaceSymbols(symReplacements)); 634 return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context); 635 } 636 637 AffineMap mlir::compressUnusedSymbols(AffineMap map) { 638 llvm::SmallDenseSet<unsigned> usedSymbols; 639 map.walkExprs([&](AffineExpr expr) { 640 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) 641 usedSymbols.insert(symExpr.getPosition()); 642 }); 643 llvm::SmallDenseSet<unsigned> unusedSymbols; 644 for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d) 645 if (!usedSymbols.contains(d)) 646 unusedSymbols.insert(d); 647 return compressSymbols(map, unusedSymbols); 648 } 649 650 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { 651 return compressUnusedImpl( 652 maps, [](AffineMap m) { return compressUnusedSymbols(m); }); 653 } 654 655 AffineMap mlir::simplifyAffineMap(AffineMap map) { 656 SmallVector<AffineExpr, 8> exprs; 657 for (auto e : map.getResults()) { 658 exprs.push_back( 659 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); 660 } 661 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, 662 map.getContext()); 663 } 664 665 AffineMap mlir::removeDuplicateExprs(AffineMap map) { 666 auto results = map.getResults(); 667 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end()); 668 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()), 669 uniqueExprs.end()); 670 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs, 671 map.getContext()); 672 } 673 674 AffineMap mlir::inversePermutation(AffineMap map) { 675 if (map.isEmpty()) 676 return map; 677 assert(map.getNumSymbols() == 0 && "expected map without symbols"); 678 SmallVector<AffineExpr, 4> exprs(map.getNumDims()); 679 for (auto en : llvm::enumerate(map.getResults())) { 680 auto expr = en.value(); 681 // Skip non-permutations. 682 if (auto d = expr.dyn_cast<AffineDimExpr>()) { 683 if (exprs[d.getPosition()]) 684 continue; 685 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); 686 } 687 } 688 SmallVector<AffineExpr, 4> seenExprs; 689 seenExprs.reserve(map.getNumDims()); 690 for (auto expr : exprs) 691 if (expr) 692 seenExprs.push_back(expr); 693 if (seenExprs.size() != map.getNumInputs()) 694 return AffineMap(); 695 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); 696 } 697 698 AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) { 699 assert(map.isProjectedPermutation()); 700 MLIRContext *context = map.getContext(); 701 AffineExpr zero = mlir::getAffineConstantExpr(0, context); 702 // Start with all the results as 0. 703 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero); 704 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 705 // Reverse each dimension existing in the oringal map result. 706 exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context); 707 } 708 return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); 709 } 710 711 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { 712 unsigned numResults = 0, numDims = 0, numSymbols = 0; 713 for (auto m : maps) 714 numResults += m.getNumResults(); 715 SmallVector<AffineExpr, 8> results; 716 results.reserve(numResults); 717 for (auto m : maps) { 718 for (auto res : m.getResults()) 719 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols)); 720 721 numSymbols += m.getNumSymbols(); 722 numDims = std::max(m.getNumDims(), numDims); 723 } 724 return AffineMap::get(numDims, numSymbols, results, 725 maps.front().getContext()); 726 } 727 728 AffineMap 729 mlir::getProjectedMap(AffineMap map, 730 const llvm::SmallDenseSet<unsigned> &unusedDims) { 731 return compressUnusedSymbols(compressDims(map, unusedDims)); 732 } 733 734 //===----------------------------------------------------------------------===// 735 // MutableAffineMap. 736 //===----------------------------------------------------------------------===// 737 738 MutableAffineMap::MutableAffineMap(AffineMap map) 739 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), 740 context(map.getContext()) { 741 for (auto result : map.getResults()) 742 results.push_back(result); 743 } 744 745 void MutableAffineMap::reset(AffineMap map) { 746 results.clear(); 747 numDims = map.getNumDims(); 748 numSymbols = map.getNumSymbols(); 749 context = map.getContext(); 750 for (auto result : map.getResults()) 751 results.push_back(result); 752 } 753 754 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 755 if (results[idx].isMultipleOf(factor)) 756 return true; 757 758 // TODO: use simplifyAffineExpr and FlatAffineConstraints to 759 // complete this (for a more powerful analysis). 760 return false; 761 } 762 763 // Simplifies the result affine expressions of this map. The expressions have to 764 // be pure for the simplification implemented. 765 void MutableAffineMap::simplify() { 766 // Simplify each of the results if possible. 767 // TODO: functional-style map 768 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 769 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 770 } 771 } 772 773 AffineMap MutableAffineMap::getAffineMap() const { 774 return AffineMap::get(numDims, numSymbols, results, context); 775 } 776