1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "AffineMapDetail.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Support/LogicalResult.h" 14 #include "mlir/Support/MathExtras.h" 15 #include "llvm/ADT/SmallBitVector.h" 16 #include "llvm/ADT/SmallSet.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 // AffineExprConstantFolder evaluates an affine expression using constant 25 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute 26 // representing the constant value of the affine expression evaluated on 27 // constant 'operandConsts', or nullptr if it can't be folded. 28 class AffineExprConstantFolder { 29 public: 30 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) 31 : numDims(numDims), operandConsts(operandConsts) {} 32 33 /// Attempt to constant fold the specified affine expr, or return null on 34 /// failure. 35 IntegerAttr constantFold(AffineExpr expr) { 36 if (auto result = constantFoldImpl(expr)) 37 return IntegerAttr::get(IndexType::get(expr.getContext()), *result); 38 return nullptr; 39 } 40 41 private: 42 Optional<int64_t> constantFoldImpl(AffineExpr expr) { 43 switch (expr.getKind()) { 44 case AffineExprKind::Add: 45 return constantFoldBinExpr( 46 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); 47 case AffineExprKind::Mul: 48 return constantFoldBinExpr( 49 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); 50 case AffineExprKind::Mod: 51 return constantFoldBinExpr( 52 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); 53 case AffineExprKind::FloorDiv: 54 return constantFoldBinExpr( 55 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); 56 case AffineExprKind::CeilDiv: 57 return constantFoldBinExpr( 58 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); 59 case AffineExprKind::Constant: 60 return expr.cast<AffineConstantExpr>().getValue(); 61 case AffineExprKind::DimId: 62 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()] 63 .dyn_cast_or_null<IntegerAttr>()) 64 return attr.getInt(); 65 return llvm::None; 66 case AffineExprKind::SymbolId: 67 if (auto attr = operandConsts[numDims + 68 expr.cast<AffineSymbolExpr>().getPosition()] 69 .dyn_cast_or_null<IntegerAttr>()) 70 return attr.getInt(); 71 return llvm::None; 72 } 73 llvm_unreachable("Unknown AffineExpr"); 74 } 75 76 // TODO: Change these to operate on APInts too. 77 Optional<int64_t> constantFoldBinExpr(AffineExpr expr, 78 int64_t (*op)(int64_t, int64_t)) { 79 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 80 if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) 81 if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) 82 return op(*lhs, *rhs); 83 return llvm::None; 84 } 85 86 // The number of dimension operands in AffineMap containing this expression. 87 unsigned numDims; 88 // The constant valued operands used to evaluate this AffineExpr. 89 ArrayRef<Attribute> operandConsts; 90 }; 91 92 } // end anonymous namespace 93 94 /// Returns a single constant result affine map. 95 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { 96 return get(/*dimCount=*/0, /*symbolCount=*/0, 97 {getAffineConstantExpr(val, context)}); 98 } 99 100 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most 101 /// minor dimensions. 102 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, 103 MLIRContext *context) { 104 assert(dims >= results && "Dimension mismatch"); 105 auto id = AffineMap::getMultiDimIdentityMap(dims, context); 106 return AffineMap::get(dims, 0, id.getResults().take_back(results), context); 107 } 108 109 bool AffineMap::isMinorIdentity() const { 110 return getNumDims() >= getNumResults() && 111 *this == 112 getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); 113 } 114 115 /// Returns true if this affine map is a minor identity up to broadcasted 116 /// dimensions which are indicated by value 0 in the result. 117 bool AffineMap::isMinorIdentityWithBroadcasting( 118 SmallVectorImpl<unsigned> *broadcastedDims) const { 119 if (broadcastedDims) 120 broadcastedDims->clear(); 121 if (getNumDims() < getNumResults()) 122 return false; 123 unsigned suffixStart = getNumDims() - getNumResults(); 124 for (auto idxAndExpr : llvm::enumerate(getResults())) { 125 unsigned resIdx = idxAndExpr.index(); 126 AffineExpr expr = idxAndExpr.value(); 127 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 128 // Each result may be either a constant 0 (broadcasted dimension). 129 if (constExpr.getValue() != 0) 130 return false; 131 if (broadcastedDims) 132 broadcastedDims->push_back(resIdx); 133 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 134 // Or it may be the input dimension corresponding to this result position. 135 if (dimExpr.getPosition() != suffixStart + resIdx) 136 return false; 137 } else { 138 return false; 139 } 140 } 141 return true; 142 } 143 144 /// Return true if this affine map can be converted to a minor identity with 145 /// broadcast by doing a permute. Return a permutation (there may be 146 /// several) to apply to get to a minor identity with broadcasts. 147 /// Ex: 148 /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with 149 /// perm = [1, 0] and broadcast d2 150 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by 151 /// permutation + broadcast 152 /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) 153 /// with perm = [1, 0, 2] and broadcast d2 154 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra 155 /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with 156 /// perm = [3, 0, 1, 2] 157 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( 158 SmallVectorImpl<unsigned> &permutedDims) const { 159 unsigned projectionStart = 160 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; 161 permutedDims.clear(); 162 SmallVector<unsigned> broadcastDims; 163 permutedDims.resize(getNumResults(), 0); 164 // If there are more results than input dimensions we want the new map to 165 // start with broadcast dimensions in order to be a minor identity with 166 // broadcasting. 167 unsigned leadingBroadcast = 168 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0; 169 llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()), 170 false); 171 for (auto idxAndExpr : llvm::enumerate(getResults())) { 172 unsigned resIdx = idxAndExpr.index(); 173 AffineExpr expr = idxAndExpr.value(); 174 // Each result may be either a constant 0 (broadcast dimension) or a 175 // dimension. 176 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 177 if (constExpr.getValue() != 0) 178 return false; 179 broadcastDims.push_back(resIdx); 180 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 181 if (dimExpr.getPosition() < projectionStart) 182 return false; 183 unsigned newPosition = 184 dimExpr.getPosition() - projectionStart + leadingBroadcast; 185 permutedDims[resIdx] = newPosition; 186 dimFound[newPosition] = true; 187 } else { 188 return false; 189 } 190 } 191 // Find a permuation for the broadcast dimension. Since they are broadcasted 192 // any valid permutation is acceptable. We just permute the dim into a slot 193 // without an existing dimension. 194 unsigned pos = 0; 195 for (auto dim : broadcastDims) { 196 while (pos < dimFound.size() && dimFound[pos]) { 197 pos++; 198 } 199 permutedDims[dim] = pos++; 200 } 201 return true; 202 } 203 204 /// Returns an AffineMap representing a permutation. 205 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, 206 MLIRContext *context) { 207 assert(!permutation.empty() && 208 "Cannot create permutation map from empty permutation vector"); 209 SmallVector<AffineExpr, 4> affExprs; 210 for (auto index : permutation) 211 affExprs.push_back(getAffineDimExpr(index, context)); 212 auto m = std::max_element(permutation.begin(), permutation.end()); 213 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); 214 assert(permutationMap.isPermutation() && "Invalid permutation vector"); 215 return permutationMap; 216 } 217 218 template <typename AffineExprContainer> 219 static SmallVector<AffineMap, 4> 220 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) { 221 assert(!exprsList.empty()); 222 assert(!exprsList[0].empty()); 223 auto context = exprsList[0][0].getContext(); 224 int64_t maxDim = -1, maxSym = -1; 225 getMaxDimAndSymbol(exprsList, maxDim, maxSym); 226 SmallVector<AffineMap, 4> maps; 227 maps.reserve(exprsList.size()); 228 for (const auto &exprs : exprsList) 229 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, 230 /*symbolCount=*/maxSym + 1, exprs, context)); 231 return maps; 232 } 233 234 SmallVector<AffineMap, 4> 235 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) { 236 return ::inferFromExprList(exprsList); 237 } 238 239 SmallVector<AffineMap, 4> 240 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) { 241 return ::inferFromExprList(exprsList); 242 } 243 244 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, 245 MLIRContext *context) { 246 SmallVector<AffineExpr, 4> dimExprs; 247 dimExprs.reserve(numDims); 248 for (unsigned i = 0; i < numDims; ++i) 249 dimExprs.push_back(mlir::getAffineDimExpr(i, context)); 250 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); 251 } 252 253 MLIRContext *AffineMap::getContext() const { return map->context; } 254 255 bool AffineMap::isIdentity() const { 256 if (getNumDims() != getNumResults()) 257 return false; 258 ArrayRef<AffineExpr> results = getResults(); 259 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { 260 auto expr = results[i].dyn_cast<AffineDimExpr>(); 261 if (!expr || expr.getPosition() != i) 262 return false; 263 } 264 return true; 265 } 266 267 bool AffineMap::isEmpty() const { 268 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; 269 } 270 271 bool AffineMap::isSingleConstant() const { 272 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>(); 273 } 274 275 bool AffineMap::isConstant() const { 276 return llvm::all_of(getResults(), [](AffineExpr expr) { 277 return expr.isa<AffineConstantExpr>(); 278 }); 279 } 280 281 int64_t AffineMap::getSingleConstantResult() const { 282 assert(isSingleConstant() && "map must have a single constant result"); 283 return getResult(0).cast<AffineConstantExpr>().getValue(); 284 } 285 286 SmallVector<int64_t> AffineMap::getConstantResults() const { 287 assert(isConstant() && "map must have only constant results"); 288 SmallVector<int64_t> result; 289 for (auto expr : getResults()) 290 result.emplace_back(expr.cast<AffineConstantExpr>().getValue()); 291 return result; 292 } 293 294 unsigned AffineMap::getNumDims() const { 295 assert(map && "uninitialized map storage"); 296 return map->numDims; 297 } 298 unsigned AffineMap::getNumSymbols() const { 299 assert(map && "uninitialized map storage"); 300 return map->numSymbols; 301 } 302 unsigned AffineMap::getNumResults() const { 303 assert(map && "uninitialized map storage"); 304 return map->results.size(); 305 } 306 unsigned AffineMap::getNumInputs() const { 307 assert(map && "uninitialized map storage"); 308 return map->numDims + map->numSymbols; 309 } 310 311 ArrayRef<AffineExpr> AffineMap::getResults() const { 312 assert(map && "uninitialized map storage"); 313 return map->results; 314 } 315 AffineExpr AffineMap::getResult(unsigned idx) const { 316 assert(map && "uninitialized map storage"); 317 return map->results[idx]; 318 } 319 320 unsigned AffineMap::getDimPosition(unsigned idx) const { 321 return getResult(idx).cast<AffineDimExpr>().getPosition(); 322 } 323 324 unsigned AffineMap::getPermutedPosition(unsigned input) const { 325 assert(isPermutation() && "invalid permutation request"); 326 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) 327 if (getDimPosition(i) == input) 328 return i; 329 llvm_unreachable("incorrect permutation request"); 330 } 331 332 /// Folds the results of the application of an affine map on the provided 333 /// operands to a constant if possible. Returns false if the folding happens, 334 /// true otherwise. 335 LogicalResult 336 AffineMap::constantFold(ArrayRef<Attribute> operandConstants, 337 SmallVectorImpl<Attribute> &results) const { 338 // Attempt partial folding. 339 SmallVector<int64_t, 2> integers; 340 partialConstantFold(operandConstants, &integers); 341 342 // If all expressions folded to a constant, populate results with attributes 343 // containing those constants. 344 if (integers.empty()) 345 return failure(); 346 347 auto range = llvm::map_range(integers, [this](int64_t i) { 348 return IntegerAttr::get(IndexType::get(getContext()), i); 349 }); 350 results.append(range.begin(), range.end()); 351 return success(); 352 } 353 354 AffineMap 355 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, 356 SmallVectorImpl<int64_t> *results) const { 357 assert(getNumInputs() == operandConstants.size()); 358 359 // Fold each of the result expressions. 360 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); 361 SmallVector<AffineExpr, 4> exprs; 362 exprs.reserve(getNumResults()); 363 364 for (auto expr : getResults()) { 365 auto folded = exprFolder.constantFold(expr); 366 // If did not fold to a constant, keep the original expression, and clear 367 // the integer results vector. 368 if (folded) { 369 exprs.push_back( 370 getAffineConstantExpr(folded.getInt(), folded.getContext())); 371 if (results) 372 results->push_back(folded.getInt()); 373 } else { 374 exprs.push_back(expr); 375 if (results) { 376 results->clear(); 377 results = nullptr; 378 } 379 } 380 } 381 382 return get(getNumDims(), getNumSymbols(), exprs, getContext()); 383 } 384 385 /// Walk all of the AffineExpr's in this mapping. Each node in an expression 386 /// tree is visited in postorder. 387 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const { 388 for (auto expr : getResults()) 389 expr.walk(callback); 390 } 391 392 /// This method substitutes any uses of dimensions and symbols (e.g. 393 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified 394 /// expression mapping. Because this can be used to eliminate dims and 395 /// symbols, the client needs to specify the number of dims and symbols in 396 /// the result. The returned map always has the same number of results. 397 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 398 ArrayRef<AffineExpr> symReplacements, 399 unsigned numResultDims, 400 unsigned numResultSyms) const { 401 SmallVector<AffineExpr, 8> results; 402 results.reserve(getNumResults()); 403 for (auto expr : getResults()) 404 results.push_back( 405 expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); 406 return get(numResultDims, numResultSyms, results, getContext()); 407 } 408 409 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to 410 /// each of the results and return a new AffineMap with the new results and 411 /// with the specified number of dims and symbols. 412 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement, 413 unsigned numResultDims, 414 unsigned numResultSyms) const { 415 SmallVector<AffineExpr, 4> newResults; 416 newResults.reserve(getNumResults()); 417 for (AffineExpr e : getResults()) 418 newResults.push_back(e.replace(expr, replacement)); 419 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 420 } 421 422 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the 423 /// results and return a new AffineMap with the new results and with the 424 /// specified number of dims and symbols. 425 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map, 426 unsigned numResultDims, 427 unsigned numResultSyms) const { 428 SmallVector<AffineExpr, 4> newResults; 429 newResults.reserve(getNumResults()); 430 for (AffineExpr e : getResults()) 431 newResults.push_back(e.replace(map)); 432 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 433 } 434 435 AffineMap 436 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 437 SmallVector<AffineExpr, 4> newResults; 438 newResults.reserve(getNumResults()); 439 for (AffineExpr e : getResults()) 440 newResults.push_back(e.replace(map)); 441 return AffineMap::inferFromExprList(newResults).front(); 442 } 443 444 AffineMap AffineMap::compose(AffineMap map) const { 445 assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); 446 // Prepare `map` by concatenating the symbols and rewriting its exprs. 447 unsigned numDims = map.getNumDims(); 448 unsigned numSymbolsThisMap = getNumSymbols(); 449 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); 450 SmallVector<AffineExpr, 8> newDims(numDims); 451 for (unsigned idx = 0; idx < numDims; ++idx) { 452 newDims[idx] = getAffineDimExpr(idx, getContext()); 453 } 454 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap); 455 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { 456 newSymbols[idx - numSymbolsThisMap] = 457 getAffineSymbolExpr(idx, getContext()); 458 } 459 auto newMap = 460 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); 461 SmallVector<AffineExpr, 8> exprs; 462 exprs.reserve(getResults().size()); 463 for (auto expr : getResults()) 464 exprs.push_back(expr.compose(newMap)); 465 return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); 466 } 467 468 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const { 469 assert(getNumSymbols() == 0 && "Expected symbol-less map"); 470 SmallVector<AffineExpr, 4> exprs; 471 exprs.reserve(values.size()); 472 MLIRContext *ctx = getContext(); 473 for (auto v : values) 474 exprs.push_back(getAffineConstantExpr(v, ctx)); 475 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); 476 SmallVector<int64_t, 4> res; 477 res.reserve(resMap.getNumResults()); 478 for (auto e : resMap.getResults()) 479 res.push_back(e.cast<AffineConstantExpr>().getValue()); 480 return res; 481 } 482 483 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const { 484 if (getNumSymbols() > 0) 485 return false; 486 487 // Having more results than inputs means that results have duplicated dims or 488 // zeros that can't be mapped to input dims. 489 if (getNumResults() > getNumInputs()) 490 return false; 491 492 SmallVector<bool, 8> seen(getNumInputs(), false); 493 // A projected permutation can have, at most, only one instance of each input 494 // dimension in the result expressions. Zeros are allowed as long as the 495 // number of result expressions is lower or equal than the number of input 496 // expressions. 497 for (auto expr : getResults()) { 498 if (auto dim = expr.dyn_cast<AffineDimExpr>()) { 499 if (seen[dim.getPosition()]) 500 return false; 501 seen[dim.getPosition()] = true; 502 } else { 503 auto constExpr = expr.dyn_cast<AffineConstantExpr>(); 504 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0) 505 return false; 506 } 507 } 508 509 // Results are either dims or zeros and zeros can be mapped to input dims. 510 return true; 511 } 512 513 bool AffineMap::isPermutation() const { 514 if (getNumDims() != getNumResults()) 515 return false; 516 return isProjectedPermutation(); 517 } 518 519 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const { 520 SmallVector<AffineExpr, 4> exprs; 521 exprs.reserve(resultPos.size()); 522 for (auto idx : resultPos) 523 exprs.push_back(getResult(idx)); 524 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); 525 } 526 527 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { 528 return AffineMap::get(getNumDims(), getNumSymbols(), 529 getResults().slice(start, length), getContext()); 530 } 531 532 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { 533 if (numResults == 0) 534 return AffineMap(); 535 if (numResults > getNumResults()) 536 return *this; 537 return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults))); 538 } 539 540 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { 541 if (numResults == 0) 542 return AffineMap(); 543 if (numResults > getNumResults()) 544 return *this; 545 return getSubMap(llvm::to_vector<4>( 546 llvm::seq<unsigned>(getNumResults() - numResults, getNumResults()))); 547 } 548 549 AffineMap mlir::compressDims(AffineMap map, 550 const llvm::SmallDenseSet<unsigned> &unusedDims) { 551 unsigned numDims = 0; 552 SmallVector<AffineExpr> dimReplacements; 553 dimReplacements.reserve(map.getNumDims()); 554 MLIRContext *context = map.getContext(); 555 for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) { 556 if (unusedDims.contains(dim)) 557 dimReplacements.push_back(getAffineConstantExpr(0, context)); 558 else 559 dimReplacements.push_back(getAffineDimExpr(numDims++, context)); 560 } 561 SmallVector<AffineExpr> resultExprs; 562 resultExprs.reserve(map.getNumResults()); 563 for (auto e : map.getResults()) 564 resultExprs.push_back(e.replaceDims(dimReplacements)); 565 return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context); 566 } 567 568 AffineMap mlir::compressUnusedDims(AffineMap map) { 569 llvm::SmallDenseSet<unsigned> usedDims; 570 map.walkExprs([&](AffineExpr expr) { 571 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) 572 usedDims.insert(dimExpr.getPosition()); 573 }); 574 llvm::SmallDenseSet<unsigned> unusedDims; 575 for (unsigned d = 0, e = map.getNumDims(); d != e; ++d) 576 if (!usedDims.contains(d)) 577 unusedDims.insert(d); 578 return compressDims(map, unusedDims); 579 } 580 581 static SmallVector<AffineMap> 582 compressUnusedImpl(ArrayRef<AffineMap> maps, 583 llvm::function_ref<AffineMap(AffineMap)> compressionFun) { 584 if (maps.empty()) 585 return SmallVector<AffineMap>(); 586 SmallVector<AffineExpr> allExprs; 587 allExprs.reserve(maps.size() * maps.front().getNumResults()); 588 unsigned numDims = maps.front().getNumDims(), 589 numSymbols = maps.front().getNumSymbols(); 590 for (auto m : maps) { 591 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && 592 "expected maps with same num dims and symbols"); 593 llvm::append_range(allExprs, m.getResults()); 594 } 595 AffineMap unifiedMap = compressionFun( 596 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext())); 597 unsigned unifiedNumDims = unifiedMap.getNumDims(), 598 unifiedNumSymbols = unifiedMap.getNumSymbols(); 599 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); 600 SmallVector<AffineMap> res; 601 res.reserve(maps.size()); 602 for (auto m : maps) { 603 res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols, 604 unifiedResults.take_front(m.getNumResults()), 605 m.getContext())); 606 unifiedResults = unifiedResults.drop_front(m.getNumResults()); 607 } 608 return res; 609 } 610 611 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { 612 return compressUnusedImpl(maps, 613 [](AffineMap m) { return compressUnusedDims(m); }); 614 } 615 616 AffineMap 617 mlir::compressSymbols(AffineMap map, 618 const llvm::SmallDenseSet<unsigned> &unusedSymbols) { 619 unsigned numSymbols = 0; 620 SmallVector<AffineExpr> symReplacements; 621 symReplacements.reserve(map.getNumSymbols()); 622 MLIRContext *context = map.getContext(); 623 for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) { 624 if (unusedSymbols.contains(sym)) 625 symReplacements.push_back(getAffineConstantExpr(0, context)); 626 else 627 symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context)); 628 } 629 SmallVector<AffineExpr> resultExprs; 630 resultExprs.reserve(map.getNumResults()); 631 for (auto e : map.getResults()) 632 resultExprs.push_back(e.replaceSymbols(symReplacements)); 633 return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context); 634 } 635 636 AffineMap mlir::compressUnusedSymbols(AffineMap map) { 637 llvm::SmallDenseSet<unsigned> usedSymbols; 638 map.walkExprs([&](AffineExpr expr) { 639 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) 640 usedSymbols.insert(symExpr.getPosition()); 641 }); 642 llvm::SmallDenseSet<unsigned> unusedSymbols; 643 for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d) 644 if (!usedSymbols.contains(d)) 645 unusedSymbols.insert(d); 646 return compressSymbols(map, unusedSymbols); 647 } 648 649 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { 650 return compressUnusedImpl( 651 maps, [](AffineMap m) { return compressUnusedSymbols(m); }); 652 } 653 654 AffineMap mlir::simplifyAffineMap(AffineMap map) { 655 SmallVector<AffineExpr, 8> exprs; 656 for (auto e : map.getResults()) { 657 exprs.push_back( 658 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); 659 } 660 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, 661 map.getContext()); 662 } 663 664 AffineMap mlir::removeDuplicateExprs(AffineMap map) { 665 auto results = map.getResults(); 666 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end()); 667 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()), 668 uniqueExprs.end()); 669 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs, 670 map.getContext()); 671 } 672 673 AffineMap mlir::inversePermutation(AffineMap map) { 674 if (map.isEmpty()) 675 return map; 676 assert(map.getNumSymbols() == 0 && "expected map without symbols"); 677 SmallVector<AffineExpr, 4> exprs(map.getNumDims()); 678 for (auto en : llvm::enumerate(map.getResults())) { 679 auto expr = en.value(); 680 // Skip non-permutations. 681 if (auto d = expr.dyn_cast<AffineDimExpr>()) { 682 if (exprs[d.getPosition()]) 683 continue; 684 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); 685 } 686 } 687 SmallVector<AffineExpr, 4> seenExprs; 688 seenExprs.reserve(map.getNumDims()); 689 for (auto expr : exprs) 690 if (expr) 691 seenExprs.push_back(expr); 692 if (seenExprs.size() != map.getNumInputs()) 693 return AffineMap(); 694 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); 695 } 696 697 AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) { 698 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true)); 699 MLIRContext *context = map.getContext(); 700 AffineExpr zero = mlir::getAffineConstantExpr(0, context); 701 // Start with all the results as 0. 702 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero); 703 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 704 // Skip zeros from input map. 'exprs' is already initialized to zero. 705 if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) { 706 assert(constExpr.getValue() == 0 && 707 "Unexpected constant in projected permutation"); 708 (void)constExpr; 709 continue; 710 } 711 712 // Reverse each dimension existing in the original map result. 713 exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context); 714 } 715 return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); 716 } 717 718 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { 719 unsigned numResults = 0, numDims = 0, numSymbols = 0; 720 for (auto m : maps) 721 numResults += m.getNumResults(); 722 SmallVector<AffineExpr, 8> results; 723 results.reserve(numResults); 724 for (auto m : maps) { 725 for (auto res : m.getResults()) 726 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols)); 727 728 numSymbols += m.getNumSymbols(); 729 numDims = std::max(m.getNumDims(), numDims); 730 } 731 return AffineMap::get(numDims, numSymbols, results, 732 maps.front().getContext()); 733 } 734 735 AffineMap 736 mlir::getProjectedMap(AffineMap map, 737 const llvm::SmallDenseSet<unsigned> &unusedDims) { 738 return compressUnusedSymbols(compressDims(map, unusedDims)); 739 } 740 741 //===----------------------------------------------------------------------===// 742 // MutableAffineMap. 743 //===----------------------------------------------------------------------===// 744 745 MutableAffineMap::MutableAffineMap(AffineMap map) 746 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), 747 context(map.getContext()) { 748 for (auto result : map.getResults()) 749 results.push_back(result); 750 } 751 752 void MutableAffineMap::reset(AffineMap map) { 753 results.clear(); 754 numDims = map.getNumDims(); 755 numSymbols = map.getNumSymbols(); 756 context = map.getContext(); 757 for (auto result : map.getResults()) 758 results.push_back(result); 759 } 760 761 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 762 if (results[idx].isMultipleOf(factor)) 763 return true; 764 765 // TODO: use simplifyAffineExpr and FlatAffineConstraints to 766 // complete this (for a more powerful analysis). 767 return false; 768 } 769 770 // Simplifies the result affine expressions of this map. The expressions have to 771 // be pure for the simplification implemented. 772 void MutableAffineMap::simplify() { 773 // Simplify each of the results if possible. 774 // TODO: functional-style map 775 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 776 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 777 } 778 } 779 780 AffineMap MutableAffineMap::getAffineMap() const { 781 return AffineMap::get(numDims, numSymbols, results, context); 782 } 783