1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "AffineMapDetail.h" 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/StandardTypes.h" 13 #include "mlir/Support/LogicalResult.h" 14 #include "mlir/Support/MathExtras.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 20 namespace { 21 22 // AffineExprConstantFolder evaluates an affine expression using constant 23 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute 24 // representing the constant value of the affine expression evaluated on 25 // constant 'operandConsts', or nullptr if it can't be folded. 26 class AffineExprConstantFolder { 27 public: 28 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) 29 : numDims(numDims), operandConsts(operandConsts) {} 30 31 /// Attempt to constant fold the specified affine expr, or return null on 32 /// failure. 
33 IntegerAttr constantFold(AffineExpr expr) { 34 if (auto result = constantFoldImpl(expr)) 35 return IntegerAttr::get(IndexType::get(expr.getContext()), *result); 36 return nullptr; 37 } 38 39 private: 40 Optional<int64_t> constantFoldImpl(AffineExpr expr) { 41 switch (expr.getKind()) { 42 case AffineExprKind::Add: 43 return constantFoldBinExpr( 44 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); 45 case AffineExprKind::Mul: 46 return constantFoldBinExpr( 47 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); 48 case AffineExprKind::Mod: 49 return constantFoldBinExpr( 50 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); 51 case AffineExprKind::FloorDiv: 52 return constantFoldBinExpr( 53 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); 54 case AffineExprKind::CeilDiv: 55 return constantFoldBinExpr( 56 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); 57 case AffineExprKind::Constant: 58 return expr.cast<AffineConstantExpr>().getValue(); 59 case AffineExprKind::DimId: 60 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()] 61 .dyn_cast_or_null<IntegerAttr>()) 62 return attr.getInt(); 63 return llvm::None; 64 case AffineExprKind::SymbolId: 65 if (auto attr = operandConsts[numDims + 66 expr.cast<AffineSymbolExpr>().getPosition()] 67 .dyn_cast_or_null<IntegerAttr>()) 68 return attr.getInt(); 69 return llvm::None; 70 } 71 llvm_unreachable("Unknown AffineExpr"); 72 } 73 74 // TODO: Change these to operate on APInts too. 75 Optional<int64_t> constantFoldBinExpr(AffineExpr expr, 76 int64_t (*op)(int64_t, int64_t)) { 77 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 78 if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) 79 if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) 80 return op(*lhs, *rhs); 81 return llvm::None; 82 } 83 84 // The number of dimension operands in AffineMap containing this expression. 
85 unsigned numDims; 86 // The constant valued operands used to evaluate this AffineExpr. 87 ArrayRef<Attribute> operandConsts; 88 }; 89 90 } // end anonymous namespace 91 92 /// Returns a single constant result affine map. 93 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { 94 return get(/*dimCount=*/0, /*symbolCount=*/0, 95 {getAffineConstantExpr(val, context)}); 96 } 97 98 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most 99 /// minor dimensions. 100 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, 101 MLIRContext *context) { 102 assert(dims >= results && "Dimension mismatch"); 103 auto id = AffineMap::getMultiDimIdentityMap(dims, context); 104 return AffineMap::get(dims, 0, id.getResults().take_back(results), context); 105 } 106 107 bool AffineMap::isMinorIdentity() const { 108 return *this == 109 getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); 110 } 111 112 /// Returns an AffineMap representing a permutation. 
113 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, 114 MLIRContext *context) { 115 assert(!permutation.empty() && 116 "Cannot create permutation map from empty permutation vector"); 117 SmallVector<AffineExpr, 4> affExprs; 118 for (auto index : permutation) 119 affExprs.push_back(getAffineDimExpr(index, context)); 120 auto m = std::max_element(permutation.begin(), permutation.end()); 121 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); 122 assert(permutationMap.isPermutation() && "Invalid permutation vector"); 123 return permutationMap; 124 } 125 126 template <typename AffineExprContainer> 127 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList, 128 int64_t &maxDim, int64_t &maxSym) { 129 for (const auto &exprs : exprsList) { 130 for (auto expr : exprs) { 131 expr.walk([&maxDim, &maxSym](AffineExpr e) { 132 if (auto d = e.dyn_cast<AffineDimExpr>()) 133 maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition())); 134 if (auto s = e.dyn_cast<AffineSymbolExpr>()) 135 maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition())); 136 }); 137 } 138 } 139 } 140 141 template <typename AffineExprContainer> 142 static SmallVector<AffineMap, 4> 143 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) { 144 assert(!exprsList.empty()); 145 assert(!exprsList[0].empty()); 146 auto context = exprsList[0][0].getContext(); 147 int64_t maxDim = -1, maxSym = -1; 148 getMaxDimAndSymbol(exprsList, maxDim, maxSym); 149 SmallVector<AffineMap, 4> maps; 150 maps.reserve(exprsList.size()); 151 for (const auto &exprs : exprsList) 152 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, 153 /*symbolCount=*/maxSym + 1, exprs, context)); 154 return maps; 155 } 156 157 SmallVector<AffineMap, 4> 158 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) { 159 return ::inferFromExprList(exprsList); 160 } 161 162 SmallVector<AffineMap, 4> 163 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 
4>> exprsList) { 164 return ::inferFromExprList(exprsList); 165 } 166 167 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, 168 MLIRContext *context) { 169 SmallVector<AffineExpr, 4> dimExprs; 170 dimExprs.reserve(numDims); 171 for (unsigned i = 0; i < numDims; ++i) 172 dimExprs.push_back(mlir::getAffineDimExpr(i, context)); 173 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); 174 } 175 176 MLIRContext *AffineMap::getContext() const { return map->context; } 177 178 bool AffineMap::isIdentity() const { 179 if (getNumDims() != getNumResults()) 180 return false; 181 ArrayRef<AffineExpr> results = getResults(); 182 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { 183 auto expr = results[i].dyn_cast<AffineDimExpr>(); 184 if (!expr || expr.getPosition() != i) 185 return false; 186 } 187 return true; 188 } 189 190 bool AffineMap::isEmpty() const { 191 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; 192 } 193 194 bool AffineMap::isSingleConstant() const { 195 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>(); 196 } 197 198 int64_t AffineMap::getSingleConstantResult() const { 199 assert(isSingleConstant() && "map must have a single constant result"); 200 return getResult(0).cast<AffineConstantExpr>().getValue(); 201 } 202 203 unsigned AffineMap::getNumDims() const { 204 assert(map && "uninitialized map storage"); 205 return map->numDims; 206 } 207 unsigned AffineMap::getNumSymbols() const { 208 assert(map && "uninitialized map storage"); 209 return map->numSymbols; 210 } 211 unsigned AffineMap::getNumResults() const { 212 assert(map && "uninitialized map storage"); 213 return map->results.size(); 214 } 215 unsigned AffineMap::getNumInputs() const { 216 assert(map && "uninitialized map storage"); 217 return map->numDims + map->numSymbols; 218 } 219 220 ArrayRef<AffineExpr> AffineMap::getResults() const { 221 assert(map && "uninitialized map storage"); 222 return map->results; 223 
} 224 AffineExpr AffineMap::getResult(unsigned idx) const { 225 assert(map && "uninitialized map storage"); 226 return map->results[idx]; 227 } 228 229 /// Folds the results of the application of an affine map on the provided 230 /// operands to a constant if possible. Returns false if the folding happens, 231 /// true otherwise. 232 LogicalResult 233 AffineMap::constantFold(ArrayRef<Attribute> operandConstants, 234 SmallVectorImpl<Attribute> &results) const { 235 // Attempt partial folding. 236 SmallVector<int64_t, 2> integers; 237 partialConstantFold(operandConstants, &integers); 238 239 // If all expressions folded to a constant, populate results with attributes 240 // containing those constants. 241 if (integers.empty()) 242 return failure(); 243 244 auto range = llvm::map_range(integers, [this](int64_t i) { 245 return IntegerAttr::get(IndexType::get(getContext()), i); 246 }); 247 results.append(range.begin(), range.end()); 248 return success(); 249 } 250 251 AffineMap 252 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, 253 SmallVectorImpl<int64_t> *results) const { 254 assert(getNumInputs() == operandConstants.size()); 255 256 // Fold each of the result expressions. 257 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); 258 SmallVector<AffineExpr, 4> exprs; 259 exprs.reserve(getNumResults()); 260 261 for (auto expr : getResults()) { 262 auto folded = exprFolder.constantFold(expr); 263 // If did not fold to a constant, keep the original expression, and clear 264 // the integer results vector. 265 if (folded) { 266 exprs.push_back( 267 getAffineConstantExpr(folded.getInt(), folded.getContext())); 268 if (results) 269 results->push_back(folded.getInt()); 270 } else { 271 exprs.push_back(expr); 272 if (results) { 273 results->clear(); 274 results = nullptr; 275 } 276 } 277 } 278 279 return get(getNumDims(), getNumSymbols(), exprs, getContext()); 280 } 281 282 /// Walk all of the AffineExpr's in this mapping. 
Each node in an expression 283 /// tree is visited in postorder. 284 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const { 285 for (auto expr : getResults()) 286 expr.walk(callback); 287 } 288 289 /// This method substitutes any uses of dimensions and symbols (e.g. 290 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified 291 /// expression mapping. Because this can be used to eliminate dims and 292 /// symbols, the client needs to specify the number of dims and symbols in 293 /// the result. The returned map always has the same number of results. 294 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 295 ArrayRef<AffineExpr> symReplacements, 296 unsigned numResultDims, 297 unsigned numResultSyms) const { 298 SmallVector<AffineExpr, 8> results; 299 results.reserve(getNumResults()); 300 for (auto expr : getResults()) 301 results.push_back( 302 expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); 303 304 return get(numResultDims, numResultSyms, results, getContext()); 305 } 306 307 AffineMap AffineMap::compose(AffineMap map) { 308 assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); 309 // Prepare `map` by concatenating the symbols and rewriting its exprs. 
310 unsigned numDims = map.getNumDims(); 311 unsigned numSymbolsThisMap = getNumSymbols(); 312 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); 313 SmallVector<AffineExpr, 8> newDims(numDims); 314 for (unsigned idx = 0; idx < numDims; ++idx) { 315 newDims[idx] = getAffineDimExpr(idx, getContext()); 316 } 317 SmallVector<AffineExpr, 8> newSymbols(numSymbols); 318 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { 319 newSymbols[idx - numSymbolsThisMap] = 320 getAffineSymbolExpr(idx, getContext()); 321 } 322 auto newMap = 323 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); 324 SmallVector<AffineExpr, 8> exprs; 325 exprs.reserve(getResults().size()); 326 for (auto expr : getResults()) 327 exprs.push_back(expr.compose(newMap)); 328 return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); 329 } 330 331 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) { 332 assert(getNumSymbols() == 0 && "Expected symbol-less map"); 333 SmallVector<AffineExpr, 4> exprs; 334 exprs.reserve(values.size()); 335 MLIRContext *ctx = getContext(); 336 for (auto v : values) 337 exprs.push_back(getAffineConstantExpr(v, ctx)); 338 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); 339 SmallVector<int64_t, 4> res; 340 res.reserve(resMap.getNumResults()); 341 for (auto e : resMap.getResults()) 342 res.push_back(e.cast<AffineConstantExpr>().getValue()); 343 return res; 344 } 345 346 bool AffineMap::isProjectedPermutation() { 347 if (getNumSymbols() > 0) 348 return false; 349 SmallVector<bool, 8> seen(getNumInputs(), false); 350 for (auto expr : getResults()) { 351 if (auto dim = expr.dyn_cast<AffineDimExpr>()) { 352 if (seen[dim.getPosition()]) 353 return false; 354 seen[dim.getPosition()] = true; 355 continue; 356 } 357 return false; 358 } 359 return true; 360 } 361 362 bool AffineMap::isPermutation() { 363 if (getNumDims() != getNumResults()) 364 return false; 365 return isProjectedPermutation(); 366 } 367 368 
/// Returns a map with the same dimensions and symbols as this one whose
/// results are the results of this map at the positions listed in
/// `resultPos`, in that order.
AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(resultPos.size());
  for (auto idx : resultPos)
    exprs.push_back(getResult(idx));
  return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
}

/// Returns the map made of the first `numResults` results. Returns a null
/// (default-constructed) map when `numResults` is zero and this map itself
/// when `numResults` exceeds the number of results.
AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
  if (numResults == 0)
    return AffineMap();
  if (numResults > getNumResults())
    return *this;
  return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
}

/// Returns the map made of the last `numResults` results. Returns a null
/// (default-constructed) map when `numResults` is zero and this map itself
/// when `numResults` exceeds the number of results.
AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
  if (numResults == 0)
    return AffineMap();
  if (numResults > getNumResults())
    return *this;
  return getSubMap(llvm::to_vector<4>(
      llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
}

/// Returns a map with the same dimension and symbol counts as `map` in which
/// every result has been passed through simplifyAffineExpr.
AffineMap mlir::simplifyAffineMap(AffineMap map) {
  SmallVector<AffineExpr, 8> exprs;
  exprs.reserve(map.getNumResults());
  for (auto e : map.getResults()) {
    exprs.push_back(
        simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
  }
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
                        map.getContext());
}

/// Returns a map with the same dimension and symbol counts as `map` whose
/// results have runs of equal neighbouring expressions collapsed to one.
/// NOTE: this relies on std::unique, so only *adjacent* duplicates are
/// removed; equal results separated by a different result are both kept.
AffineMap mlir::removeDuplicateExprs(AffineMap map) {
  auto results = map.getResults();
  SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
  uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
                    uniqueExprs.end());
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
                        map.getContext());
}

/// Returns a map in which dimension d_j maps to d_i where result #i of `map`
/// is the first result equal to d_j. An empty map is returned unchanged.
/// Returns a null map when some input dimension is not covered by a plain
/// dimension result. `map` must not have symbols.
AffineMap mlir::inversePermutation(AffineMap map) {
  if (map.isEmpty())
    return map;
  assert(map.getNumSymbols() == 0 && "expected map without symbols");
  // exprs[j] stays null until a result that is exactly d_j is found.
  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
  for (auto en : llvm::enumerate(map.getResults())) {
    auto expr = en.value();
    // Skip non-permutations.
    if (auto d = expr.dyn_cast<AffineDimExpr>()) {
      // Only the first occurrence of a dimension is recorded.
      if (exprs[d.getPosition()])
        continue;
      exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
    }
  }
  SmallVector<AffineExpr, 4> seenExprs;
  seenExprs.reserve(map.getNumDims());
  for (auto expr : exprs)
    if (expr)
      seenExprs.push_back(expr);
  // Some dimension never appeared as a result: no inverse exists.
  if (seenExprs.size() != map.getNumInputs())
    return AffineMap();
  return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
}

/// Concatenates the results of `maps` into a single map. The symbols of each
/// map are shifted past the symbols of all preceding maps, so the result has
/// the sum of all symbol counts; its dimension count is the largest dimension
/// count among `maps`. `maps` must be non-empty because the context is taken
/// from its first element.
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
  assert(!maps.empty() && "expected at least one map to concatenate");
  unsigned numResults = 0, numDims = 0, numSymbols = 0;
  for (auto m : maps)
    numResults += m.getNumResults();
  SmallVector<AffineExpr, 8> results;
  results.reserve(numResults);
  for (auto m : maps) {
    // `numSymbols` is the number of symbols contributed by earlier maps.
    for (auto res : m.getResults())
      results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));

    numSymbols += m.getNumSymbols();
    numDims = std::max(m.getNumDims(), numDims);
  }
  return AffineMap::get(numDims, numSymbols, results,
                        maps.front().getContext());
}

//===----------------------------------------------------------------------===//
// MutableAffineMap.
455 //===----------------------------------------------------------------------===// 456 457 MutableAffineMap::MutableAffineMap(AffineMap map) 458 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), 459 context(map.getContext()) { 460 for (auto result : map.getResults()) 461 results.push_back(result); 462 } 463 464 void MutableAffineMap::reset(AffineMap map) { 465 results.clear(); 466 numDims = map.getNumDims(); 467 numSymbols = map.getNumSymbols(); 468 context = map.getContext(); 469 for (auto result : map.getResults()) 470 results.push_back(result); 471 } 472 473 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 474 if (results[idx].isMultipleOf(factor)) 475 return true; 476 477 // TODO: use simplifyAffineExpr and FlatAffineConstraints to 478 // complete this (for a more powerful analysis). 479 return false; 480 } 481 482 // Simplifies the result affine expressions of this map. The expressions have to 483 // be pure for the simplification implemented. 484 void MutableAffineMap::simplify() { 485 // Simplify each of the results if possible. 486 // TODO: functional-style map 487 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 488 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 489 } 490 } 491 492 AffineMap MutableAffineMap::getAffineMap() const { 493 return AffineMap::get(numDims, numSymbols, results, context); 494 } 495