1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "AffineMapDetail.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Support/LogicalResult.h" 14 #include "mlir/Support/MathExtras.h" 15 #include "llvm/ADT/SmallBitVector.h" 16 #include "llvm/ADT/SmallSet.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 // AffineExprConstantFolder evaluates an affine expression using constant 25 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute 26 // representing the constant value of the affine expression evaluated on 27 // constant 'operandConsts', or nullptr if it can't be folded. 28 class AffineExprConstantFolder { 29 public: 30 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) 31 : numDims(numDims), operandConsts(operandConsts) {} 32 33 /// Attempt to constant fold the specified affine expr, or return null on 34 /// failure. 35 IntegerAttr constantFold(AffineExpr expr) { 36 if (auto result = constantFoldImpl(expr)) 37 return IntegerAttr::get(IndexType::get(expr.getContext()), *result); 38 return nullptr; 39 } 40 41 private: 42 Optional<int64_t> constantFoldImpl(AffineExpr expr) { 43 switch (expr.getKind()) { 44 case AffineExprKind::Add: 45 return constantFoldBinExpr( 46 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); 47 case AffineExprKind::Mul: 48 return constantFoldBinExpr( 49 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); 50 case AffineExprKind::Mod: 51 return constantFoldBinExpr( 52 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); 53 case AffineExprKind::FloorDiv: 54 return constantFoldBinExpr( 55 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); 56 case AffineExprKind::CeilDiv: 57 return constantFoldBinExpr( 58 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); 59 case AffineExprKind::Constant: 60 return expr.cast<AffineConstantExpr>().getValue(); 61 case AffineExprKind::DimId: 62 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()] 63 .dyn_cast_or_null<IntegerAttr>()) 64 return attr.getInt(); 65 return llvm::None; 66 case AffineExprKind::SymbolId: 67 if (auto attr = operandConsts[numDims + 68 expr.cast<AffineSymbolExpr>().getPosition()] 69 .dyn_cast_or_null<IntegerAttr>()) 70 return attr.getInt(); 71 return llvm::None; 72 } 73 llvm_unreachable("Unknown AffineExpr"); 74 } 75 76 // TODO: Change these to operate on APInts too. 77 Optional<int64_t> constantFoldBinExpr(AffineExpr expr, 78 int64_t (*op)(int64_t, int64_t)) { 79 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 80 if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) 81 if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) 82 return op(*lhs, *rhs); 83 return llvm::None; 84 } 85 86 // The number of dimension operands in AffineMap containing this expression. 87 unsigned numDims; 88 // The constant valued operands used to evaluate this AffineExpr. 89 ArrayRef<Attribute> operandConsts; 90 }; 91 92 } // namespace 93 94 /// Returns a single constant result affine map. 95 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { 96 return get(/*dimCount=*/0, /*symbolCount=*/0, 97 {getAffineConstantExpr(val, context)}); 98 } 99 100 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most 101 /// minor dimensions. 102 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, 103 MLIRContext *context) { 104 assert(dims >= results && "Dimension mismatch"); 105 auto id = AffineMap::getMultiDimIdentityMap(dims, context); 106 return AffineMap::get(dims, 0, id.getResults().take_back(results), context); 107 } 108 109 bool AffineMap::isMinorIdentity() const { 110 return getNumDims() >= getNumResults() && 111 *this == 112 getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); 113 } 114 115 /// Returns true if this affine map is a minor identity up to broadcasted 116 /// dimensions which are indicated by value 0 in the result. 117 bool AffineMap::isMinorIdentityWithBroadcasting( 118 SmallVectorImpl<unsigned> *broadcastedDims) const { 119 if (broadcastedDims) 120 broadcastedDims->clear(); 121 if (getNumDims() < getNumResults()) 122 return false; 123 unsigned suffixStart = getNumDims() - getNumResults(); 124 for (const auto &idxAndExpr : llvm::enumerate(getResults())) { 125 unsigned resIdx = idxAndExpr.index(); 126 AffineExpr expr = idxAndExpr.value(); 127 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 128 // Each result may be either a constant 0 (broadcasted dimension). 129 if (constExpr.getValue() != 0) 130 return false; 131 if (broadcastedDims) 132 broadcastedDims->push_back(resIdx); 133 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 134 // Or it may be the input dimension corresponding to this result position. 135 if (dimExpr.getPosition() != suffixStart + resIdx) 136 return false; 137 } else { 138 return false; 139 } 140 } 141 return true; 142 } 143 144 /// Return true if this affine map can be converted to a minor identity with 145 /// broadcast by doing a permute. Return a permutation (there may be 146 /// several) to apply to get to a minor identity with broadcasts. 147 /// Ex: 148 /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with 149 /// perm = [1, 0] and broadcast d2 150 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by 151 /// permutation + broadcast 152 /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) 153 /// with perm = [1, 0, 2] and broadcast d2 154 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra 155 /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with 156 /// perm = [3, 0, 1, 2] 157 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( 158 SmallVectorImpl<unsigned> &permutedDims) const { 159 unsigned projectionStart = 160 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; 161 permutedDims.clear(); 162 SmallVector<unsigned> broadcastDims; 163 permutedDims.resize(getNumResults(), 0); 164 // If there are more results than input dimensions we want the new map to 165 // start with broadcast dimensions in order to be a minor identity with 166 // broadcasting. 167 unsigned leadingBroadcast = 168 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0; 169 llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()), 170 false); 171 for (const auto &idxAndExpr : llvm::enumerate(getResults())) { 172 unsigned resIdx = idxAndExpr.index(); 173 AffineExpr expr = idxAndExpr.value(); 174 // Each result may be either a constant 0 (broadcast dimension) or a 175 // dimension. 176 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 177 if (constExpr.getValue() != 0) 178 return false; 179 broadcastDims.push_back(resIdx); 180 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { 181 if (dimExpr.getPosition() < projectionStart) 182 return false; 183 unsigned newPosition = 184 dimExpr.getPosition() - projectionStart + leadingBroadcast; 185 permutedDims[resIdx] = newPosition; 186 dimFound[newPosition] = true; 187 } else { 188 return false; 189 } 190 } 191 // Find a permuation for the broadcast dimension. Since they are broadcasted 192 // any valid permutation is acceptable. We just permute the dim into a slot 193 // without an existing dimension. 194 unsigned pos = 0; 195 for (auto dim : broadcastDims) { 196 while (pos < dimFound.size() && dimFound[pos]) { 197 pos++; 198 } 199 permutedDims[dim] = pos++; 200 } 201 return true; 202 } 203 204 /// Returns an AffineMap representing a permutation. 205 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, 206 MLIRContext *context) { 207 assert(!permutation.empty() && 208 "Cannot create permutation map from empty permutation vector"); 209 SmallVector<AffineExpr, 4> affExprs; 210 for (auto index : permutation) 211 affExprs.push_back(getAffineDimExpr(index, context)); 212 const auto *m = std::max_element(permutation.begin(), permutation.end()); 213 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); 214 assert(permutationMap.isPermutation() && "Invalid permutation vector"); 215 return permutationMap; 216 } 217 218 template <typename AffineExprContainer> 219 static SmallVector<AffineMap, 4> 220 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) { 221 assert(!exprsList.empty()); 222 assert(!exprsList[0].empty()); 223 auto context = exprsList[0][0].getContext(); 224 int64_t maxDim = -1, maxSym = -1; 225 getMaxDimAndSymbol(exprsList, maxDim, maxSym); 226 SmallVector<AffineMap, 4> maps; 227 maps.reserve(exprsList.size()); 228 for (const auto &exprs : exprsList) 229 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, 230 /*symbolCount=*/maxSym + 1, exprs, context)); 231 return maps; 232 } 233 234 SmallVector<AffineMap, 4> 235 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) { 236 return ::inferFromExprList(exprsList); 237 } 238 239 SmallVector<AffineMap, 4> 240 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) { 241 return ::inferFromExprList(exprsList); 242 } 243 244 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, 245 MLIRContext *context) { 246 SmallVector<AffineExpr, 4> dimExprs; 247 dimExprs.reserve(numDims); 248 for (unsigned i = 0; i < numDims; ++i) 249 dimExprs.push_back(mlir::getAffineDimExpr(i, context)); 250 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); 251 } 252 253 MLIRContext *AffineMap::getContext() const { return map->context; } 254 255 bool AffineMap::isIdentity() const { 256 if (getNumDims() != getNumResults()) 257 return false; 258 ArrayRef<AffineExpr> results = getResults(); 259 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { 260 auto expr = results[i].dyn_cast<AffineDimExpr>(); 261 if (!expr || expr.getPosition() != i) 262 return false; 263 } 264 return true; 265 } 266 267 bool AffineMap::isEmpty() const { 268 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; 269 } 270 271 bool AffineMap::isSingleConstant() const { 272 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>(); 273 } 274 275 bool AffineMap::isConstant() const { 276 return llvm::all_of(getResults(), [](AffineExpr expr) { 277 return expr.isa<AffineConstantExpr>(); 278 }); 279 } 280 281 int64_t AffineMap::getSingleConstantResult() const { 282 assert(isSingleConstant() && "map must have a single constant result"); 283 return getResult(0).cast<AffineConstantExpr>().getValue(); 284 } 285 286 SmallVector<int64_t> AffineMap::getConstantResults() const { 287 assert(isConstant() && "map must have only constant results"); 288 SmallVector<int64_t> result; 289 for (auto expr : getResults()) 290 result.emplace_back(expr.cast<AffineConstantExpr>().getValue()); 291 return result; 292 } 293 294 unsigned AffineMap::getNumDims() const { 295 assert(map && "uninitialized map storage"); 296 return map->numDims; 297 } 298 unsigned AffineMap::getNumSymbols() const { 299 assert(map && "uninitialized map storage"); 300 return map->numSymbols; 301 } 302 unsigned AffineMap::getNumResults() const { return getResults().size(); } 303 unsigned AffineMap::getNumInputs() const { 304 assert(map && "uninitialized map storage"); 305 return map->numDims + map->numSymbols; 306 } 307 ArrayRef<AffineExpr> AffineMap::getResults() const { 308 assert(map && "uninitialized map storage"); 309 return map->results(); 310 } 311 AffineExpr AffineMap::getResult(unsigned idx) const { 312 return getResults()[idx]; 313 } 314 315 unsigned AffineMap::getDimPosition(unsigned idx) const { 316 return getResult(idx).cast<AffineDimExpr>().getPosition(); 317 } 318 319 unsigned AffineMap::getPermutedPosition(unsigned input) const { 320 assert(isPermutation() && "invalid permutation request"); 321 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) 322 if (getDimPosition(i) == input) 323 return i; 324 llvm_unreachable("incorrect permutation request"); 325 } 326 327 /// Folds the results of the application of an affine map on the provided 328 /// operands to a constant if possible. Returns false if the folding happens, 329 /// true otherwise. 330 LogicalResult 331 AffineMap::constantFold(ArrayRef<Attribute> operandConstants, 332 SmallVectorImpl<Attribute> &results) const { 333 // Attempt partial folding. 334 SmallVector<int64_t, 2> integers; 335 partialConstantFold(operandConstants, &integers); 336 337 // If all expressions folded to a constant, populate results with attributes 338 // containing those constants. 339 if (integers.empty()) 340 return failure(); 341 342 auto range = llvm::map_range(integers, [this](int64_t i) { 343 return IntegerAttr::get(IndexType::get(getContext()), i); 344 }); 345 results.append(range.begin(), range.end()); 346 return success(); 347 } 348 349 AffineMap 350 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, 351 SmallVectorImpl<int64_t> *results) const { 352 assert(getNumInputs() == operandConstants.size()); 353 354 // Fold each of the result expressions. 355 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); 356 SmallVector<AffineExpr, 4> exprs; 357 exprs.reserve(getNumResults()); 358 359 for (auto expr : getResults()) { 360 auto folded = exprFolder.constantFold(expr); 361 // If did not fold to a constant, keep the original expression, and clear 362 // the integer results vector. 363 if (folded) { 364 exprs.push_back( 365 getAffineConstantExpr(folded.getInt(), folded.getContext())); 366 if (results) 367 results->push_back(folded.getInt()); 368 } else { 369 exprs.push_back(expr); 370 if (results) { 371 results->clear(); 372 results = nullptr; 373 } 374 } 375 } 376 377 return get(getNumDims(), getNumSymbols(), exprs, getContext()); 378 } 379 380 /// Walk all of the AffineExpr's in this mapping. Each node in an expression 381 /// tree is visited in postorder. 382 void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const { 383 for (auto expr : getResults()) 384 expr.walk(callback); 385 } 386 387 /// This method substitutes any uses of dimensions and symbols (e.g. 388 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified 389 /// expression mapping. Because this can be used to eliminate dims and 390 /// symbols, the client needs to specify the number of dims and symbols in 391 /// the result. The returned map always has the same number of results. 392 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 393 ArrayRef<AffineExpr> symReplacements, 394 unsigned numResultDims, 395 unsigned numResultSyms) const { 396 SmallVector<AffineExpr, 8> results; 397 results.reserve(getNumResults()); 398 for (auto expr : getResults()) 399 results.push_back( 400 expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); 401 return get(numResultDims, numResultSyms, results, getContext()); 402 } 403 404 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to 405 /// each of the results and return a new AffineMap with the new results and 406 /// with the specified number of dims and symbols. 407 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement, 408 unsigned numResultDims, 409 unsigned numResultSyms) const { 410 SmallVector<AffineExpr, 4> newResults; 411 newResults.reserve(getNumResults()); 412 for (AffineExpr e : getResults()) 413 newResults.push_back(e.replace(expr, replacement)); 414 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 415 } 416 417 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the 418 /// results and return a new AffineMap with the new results and with the 419 /// specified number of dims and symbols. 420 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map, 421 unsigned numResultDims, 422 unsigned numResultSyms) const { 423 SmallVector<AffineExpr, 4> newResults; 424 newResults.reserve(getNumResults()); 425 for (AffineExpr e : getResults()) 426 newResults.push_back(e.replace(map)); 427 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 428 } 429 430 AffineMap 431 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 432 SmallVector<AffineExpr, 4> newResults; 433 newResults.reserve(getNumResults()); 434 for (AffineExpr e : getResults()) 435 newResults.push_back(e.replace(map)); 436 return AffineMap::inferFromExprList(newResults).front(); 437 } 438 439 AffineMap AffineMap::compose(AffineMap map) const { 440 assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); 441 // Prepare `map` by concatenating the symbols and rewriting its exprs. 442 unsigned numDims = map.getNumDims(); 443 unsigned numSymbolsThisMap = getNumSymbols(); 444 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); 445 SmallVector<AffineExpr, 8> newDims(numDims); 446 for (unsigned idx = 0; idx < numDims; ++idx) { 447 newDims[idx] = getAffineDimExpr(idx, getContext()); 448 } 449 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap); 450 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { 451 newSymbols[idx - numSymbolsThisMap] = 452 getAffineSymbolExpr(idx, getContext()); 453 } 454 auto newMap = 455 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); 456 SmallVector<AffineExpr, 8> exprs; 457 exprs.reserve(getResults().size()); 458 for (auto expr : getResults()) 459 exprs.push_back(expr.compose(newMap)); 460 return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); 461 } 462 463 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const { 464 assert(getNumSymbols() == 0 && "Expected symbol-less map"); 465 SmallVector<AffineExpr, 4> exprs; 466 exprs.reserve(values.size()); 467 MLIRContext *ctx = getContext(); 468 for (auto v : values) 469 exprs.push_back(getAffineConstantExpr(v, ctx)); 470 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); 471 SmallVector<int64_t, 4> res; 472 res.reserve(resMap.getNumResults()); 473 for (auto e : resMap.getResults()) 474 res.push_back(e.cast<AffineConstantExpr>().getValue()); 475 return res; 476 } 477 478 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const { 479 if (getNumSymbols() > 0) 480 return false; 481 482 // Having more results than inputs means that results have duplicated dims or 483 // zeros that can't be mapped to input dims. 484 if (getNumResults() > getNumInputs()) 485 return false; 486 487 SmallVector<bool, 8> seen(getNumInputs(), false); 488 // A projected permutation can have, at most, only one instance of each input 489 // dimension in the result expressions. Zeros are allowed as long as the 490 // number of result expressions is lower or equal than the number of input 491 // expressions. 492 for (auto expr : getResults()) { 493 if (auto dim = expr.dyn_cast<AffineDimExpr>()) { 494 if (seen[dim.getPosition()]) 495 return false; 496 seen[dim.getPosition()] = true; 497 } else { 498 auto constExpr = expr.dyn_cast<AffineConstantExpr>(); 499 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0) 500 return false; 501 } 502 } 503 504 // Results are either dims or zeros and zeros can be mapped to input dims. 505 return true; 506 } 507 508 bool AffineMap::isPermutation() const { 509 if (getNumDims() != getNumResults()) 510 return false; 511 return isProjectedPermutation(); 512 } 513 514 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const { 515 SmallVector<AffineExpr, 4> exprs; 516 exprs.reserve(resultPos.size()); 517 for (auto idx : resultPos) 518 exprs.push_back(getResult(idx)); 519 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); 520 } 521 522 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { 523 return AffineMap::get(getNumDims(), getNumSymbols(), 524 getResults().slice(start, length), getContext()); 525 } 526 527 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { 528 if (numResults == 0) 529 return AffineMap(); 530 if (numResults > getNumResults()) 531 return *this; 532 return getSliceMap(0, numResults); 533 } 534 535 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { 536 if (numResults == 0) 537 return AffineMap(); 538 if (numResults > getNumResults()) 539 return *this; 540 return getSliceMap(getNumResults() - numResults, numResults); 541 } 542 543 AffineMap mlir::compressDims(AffineMap map, 544 const llvm::SmallBitVector &unusedDims) { 545 unsigned numDims = 0; 546 SmallVector<AffineExpr> dimReplacements; 547 dimReplacements.reserve(map.getNumDims()); 548 MLIRContext *context = map.getContext(); 549 for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) { 550 if (unusedDims.test(dim)) 551 dimReplacements.push_back(getAffineConstantExpr(0, context)); 552 else 553 dimReplacements.push_back(getAffineDimExpr(numDims++, context)); 554 } 555 SmallVector<AffineExpr> resultExprs; 556 resultExprs.reserve(map.getNumResults()); 557 for (auto e : map.getResults()) 558 resultExprs.push_back(e.replaceDims(dimReplacements)); 559 return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context); 560 } 561 562 AffineMap mlir::compressUnusedDims(AffineMap map) { 563 llvm::SmallBitVector unusedDims(map.getNumDims(), true); 564 map.walkExprs([&](AffineExpr expr) { 565 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) 566 unusedDims.reset(dimExpr.getPosition()); 567 }); 568 return compressDims(map, unusedDims); 569 } 570 571 static SmallVector<AffineMap> 572 compressUnusedImpl(ArrayRef<AffineMap> maps, 573 llvm::function_ref<AffineMap(AffineMap)> compressionFun) { 574 if (maps.empty()) 575 return SmallVector<AffineMap>(); 576 SmallVector<AffineExpr> allExprs; 577 allExprs.reserve(maps.size() * maps.front().getNumResults()); 578 unsigned numDims = maps.front().getNumDims(), 579 numSymbols = maps.front().getNumSymbols(); 580 for (auto m : maps) { 581 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && 582 "expected maps with same num dims and symbols"); 583 llvm::append_range(allExprs, m.getResults()); 584 } 585 AffineMap unifiedMap = compressionFun( 586 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext())); 587 unsigned unifiedNumDims = unifiedMap.getNumDims(), 588 unifiedNumSymbols = unifiedMap.getNumSymbols(); 589 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); 590 SmallVector<AffineMap> res; 591 res.reserve(maps.size()); 592 for (auto m : maps) { 593 res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols, 594 unifiedResults.take_front(m.getNumResults()), 595 m.getContext())); 596 unifiedResults = unifiedResults.drop_front(m.getNumResults()); 597 } 598 return res; 599 } 600 601 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { 602 return compressUnusedImpl(maps, 603 [](AffineMap m) { return compressUnusedDims(m); }); 604 } 605 606 AffineMap mlir::compressSymbols(AffineMap map, 607 const llvm::SmallBitVector &unusedSymbols) { 608 unsigned numSymbols = 0; 609 SmallVector<AffineExpr> symReplacements; 610 symReplacements.reserve(map.getNumSymbols()); 611 MLIRContext *context = map.getContext(); 612 for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) { 613 if (unusedSymbols.test(sym)) 614 symReplacements.push_back(getAffineConstantExpr(0, context)); 615 else 616 symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context)); 617 } 618 SmallVector<AffineExpr> resultExprs; 619 resultExprs.reserve(map.getNumResults()); 620 for (auto e : map.getResults()) 621 resultExprs.push_back(e.replaceSymbols(symReplacements)); 622 return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context); 623 } 624 625 AffineMap mlir::compressUnusedSymbols(AffineMap map) { 626 llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true); 627 map.walkExprs([&](AffineExpr expr) { 628 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) 629 unusedSymbols.reset(symExpr.getPosition()); 630 }); 631 return compressSymbols(map, unusedSymbols); 632 } 633 634 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { 635 return compressUnusedImpl( 636 maps, [](AffineMap m) { return compressUnusedSymbols(m); }); 637 } 638 639 AffineMap mlir::simplifyAffineMap(AffineMap map) { 640 SmallVector<AffineExpr, 8> exprs; 641 for (auto e : map.getResults()) { 642 exprs.push_back( 643 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); 644 } 645 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, 646 map.getContext()); 647 } 648 649 AffineMap mlir::removeDuplicateExprs(AffineMap map) { 650 auto results = map.getResults(); 651 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end()); 652 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()), 653 uniqueExprs.end()); 654 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs, 655 map.getContext()); 656 } 657 658 AffineMap mlir::inversePermutation(AffineMap map) { 659 if (map.isEmpty()) 660 return map; 661 assert(map.getNumSymbols() == 0 && "expected map without symbols"); 662 SmallVector<AffineExpr, 4> exprs(map.getNumDims()); 663 for (const auto &en : llvm::enumerate(map.getResults())) { 664 auto expr = en.value(); 665 // Skip non-permutations. 666 if (auto d = expr.dyn_cast<AffineDimExpr>()) { 667 if (exprs[d.getPosition()]) 668 continue; 669 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); 670 } 671 } 672 SmallVector<AffineExpr, 4> seenExprs; 673 seenExprs.reserve(map.getNumDims()); 674 for (auto expr : exprs) 675 if (expr) 676 seenExprs.push_back(expr); 677 if (seenExprs.size() != map.getNumInputs()) 678 return AffineMap(); 679 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); 680 } 681 682 AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { 683 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true)); 684 MLIRContext *context = map.getContext(); 685 AffineExpr zero = mlir::getAffineConstantExpr(0, context); 686 // Start with all the results as 0. 687 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero); 688 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 689 // Skip zeros from input map. 'exprs' is already initialized to zero. 690 if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) { 691 assert(constExpr.getValue() == 0 && 692 "Unexpected constant in projected permutation"); 693 (void)constExpr; 694 continue; 695 } 696 697 // Reverse each dimension existing in the original map result. 698 exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context); 699 } 700 return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); 701 } 702 703 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { 704 unsigned numResults = 0, numDims = 0, numSymbols = 0; 705 for (auto m : maps) 706 numResults += m.getNumResults(); 707 SmallVector<AffineExpr, 8> results; 708 results.reserve(numResults); 709 for (auto m : maps) { 710 for (auto res : m.getResults()) 711 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols)); 712 713 numSymbols += m.getNumSymbols(); 714 numDims = std::max(m.getNumDims(), numDims); 715 } 716 return AffineMap::get(numDims, numSymbols, results, 717 maps.front().getContext()); 718 } 719 720 AffineMap mlir::getProjectedMap(AffineMap map, 721 const llvm::SmallBitVector &unusedDims) { 722 return compressUnusedSymbols(compressDims(map, unusedDims)); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // MutableAffineMap. 727 //===----------------------------------------------------------------------===// 728 729 MutableAffineMap::MutableAffineMap(AffineMap map) 730 : results(map.getResults().begin(), map.getResults().end()), 731 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), 732 context(map.getContext()) {} 733 734 void MutableAffineMap::reset(AffineMap map) { 735 results.clear(); 736 numDims = map.getNumDims(); 737 numSymbols = map.getNumSymbols(); 738 context = map.getContext(); 739 llvm::append_range(results, map.getResults()); 740 } 741 742 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 743 if (results[idx].isMultipleOf(factor)) 744 return true; 745 746 // TODO: use simplifyAffineExpr and FlatAffineConstraints to 747 // complete this (for a more powerful analysis). 748 return false; 749 } 750 751 // Simplifies the result affine expressions of this map. The expressions have to 752 // be pure for the simplification implemented. 753 void MutableAffineMap::simplify() { 754 // Simplify each of the results if possible. 755 // TODO: functional-style map 756 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 757 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 758 } 759 } 760 761 AffineMap MutableAffineMap::getAffineMap() const { 762 return AffineMap::get(numDims, numSymbols, results, context); 763 } 764