1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 21 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/ADT/ScopeExit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <type_traits> 32 33 #define DEBUG_TYPE "linalg-transforms" 34 35 using namespace mlir; 36 using namespace mlir::edsc; 37 using namespace mlir::edsc::intrinsics; 38 using namespace mlir::linalg; 39 40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 41 42 //===----------------------------------------------------------------------===// 43 // Transformations exposed as rewrite patterns. 44 //===----------------------------------------------------------------------===// 45 // Marker used as attribute name in generated Linalg rewriting transformations. 
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

/// Build a filter from a marker disjunction only: the op predicate defaults to
/// one that always succeeds, so matching is decided purely by the marker.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : LinalgTransformationFilter([](Operation *) { return success(); },
                                 matchDisjunction, replacement) {}

/// Build a filter from an op predicate `f`, a marker disjunction, and an
/// optional `replacement` marker that replaceLinalgTransformationFilter
/// applies after a successful rewrite.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filter(f),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

/// Return success if `op` passes the predicate (when present) and carries an
/// acceptable `kLinalgTransformMarker` attribute; otherwise notify `rewriter`
/// of the match failure with a diagnostic listing the expected markers.
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (filter && failed(filter(op)))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker case and matchDisjunction is empty.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no marker but was expecting a marker.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any marker from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit marker.
  for (auto marker : matchDisjunction)
    if (attr.getValue() == marker)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any marker from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

/// Update the transformation marker on `op`: set it to `replacement` when one
/// was provided, otherwise remove the marker attribute entirely.
void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

/// Capture `ts` by value and install a tile-size computation function that
/// materializes the sizes as constant index ops. The constants are inserted
/// at the start of the enclosing FuncOp's entry block so they dominate all
/// uses within the function.
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`
/// Return success if either:
/// 1. The operand is already statically shaped, `result` is left unchanged.
/// 2. The operand is (partially) dynamic, `result` is the result of a freshly
/// created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.getType().cast<RankedTensorType>();
  // Already static shape, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // Each mixed size is either a static Attribute or a dynamic Value; for
    // the dynamic case fall back to the smallest static bounding index.
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  // NOTE(review): assumes options.paddingValueComputationFunction is non-null
  // here; the caller (the padding path in the tiling pattern) checks this
  // before reaching here — confirm if adding new callers.
  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = linalg::PadTensorOp::createPadHighOp(staticTensorType, operand, pad,
                                                opToPad->getLoc(), rewriter);
  return success();
}

// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands.
// `res.op` is updated to the cloned static form of the op on success.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> operands = opToPad.getShapedOperands();
  for (Value &v : operands) {
    // Default-constructed Value is null; it remains null when the operand is
    // already statically shaped (case 1 of the helper above).
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
                                                     options, paddedOperand))) {
      return failure();
    }
    // Update v if we indeed got a padded operand.
    v = paddedOperand ? paddedOperand : v;
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  operands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, operands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  llvm::SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    // Sizes come from DimOps on the *original* result; record them so the
    // replacement below can skip these deliberately-created uses.
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
/// Constructor restricting the pattern to operations named `opName`.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

/// Constructor matching any operation type; filtering is left to `marker`.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    LinalgTilingOptions options, LinalgTransformationFilter marker,
    PatternBenefit benefit)
    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
      options(options) {}

/// Tile (and optionally pad) `op`. On success `result` receives the tiled op
/// and its loop nest; replacement of the original op is left to derived
/// patterns. Fails when `op` is not a LinalgOp, the marker filter rejects it,
/// tiling fails, or on-the-fly padding was requested but no static bounding
/// box can be computed.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Setup RAII guard to return properly.
  bool succeeded = true;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    if (!succeeded)
      return;
    // Return relevant information to derived pattern.
    result = *res;
    // Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
    marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      marker.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Set so RAII guard does not propagate TiledLinalgOp to `result`.
    succeeded = false;
    return failure();
  }

  // Do not perform replacement of `linalgOp`, let the derived patterns
  // do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

/// Values that replace the original op's results: the outermost generated
/// loop's results when loops exist, otherwise the tiled op's own results.
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

/// Same as above for the tile-and-fuse result: prefer the outermost fused
/// loop's results when fused loops exist.
static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), marker(marker),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

/// Tile `op` and fuse the producers selected by
/// `fusionOptions.indicesToFuse` into the tiled loop nest, then tile any
/// remaining (unfused) loops separately.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the LinalgOp producers that feed the operands selected for
  // fusion, according to the dependence graph.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the selected producers in block order (all precede `op`), with
  // `op` itself appended last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Already-fused loop dimensions get a zero tile
  // size so they are not tiled a second time.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update markers: the tiled root op, each fused producer, and the original
  // producer ops (left in place) each receive their respective marker.
  marker.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg base interchange pattern.
396 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 397 StringRef opName, MLIRContext *context, 398 ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker, 399 PatternBenefit benefit) 400 : RewritePattern(opName, {}, benefit, context), marker(marker), 401 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 402 403 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 404 Operation *op, PatternRewriter &rewriter) const { 405 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 406 if (!linalgOp) 407 return failure(); 408 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 409 return failure(); 410 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 411 return failure(); 412 413 // TODO: figure out how this interplays with named ops. In particular this 414 // should break the named op property. 415 rewriter.updateRootInPlace(op, [&]() { 416 interchange(linalgOp, interchangeVector); 417 // New marker if specified. 418 marker.replaceLinalgTransformationFilter(rewriter, op); 419 }); 420 return success(); 421 } 422 423 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 424 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 425 LinalgTransformationFilter marker, PatternBenefit benefit) 426 : RewritePattern(opName, {}, benefit, context), marker(marker), 427 options(options) {} 428 429 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 430 Operation *op, PatternRewriter &rewriter) const { 431 if (failed(marker.checkAndNotify(rewriter, op))) 432 return failure(); 433 if (failed(promoteSubviewsPrecondition(op, options))) 434 return failure(); 435 436 // TODO: We cannot use root update here. This pattern is creating other ops, 437 // so if the promotion fails, those need to be cleaned up, which doesnt seem 438 // to be happening here. 
So to fail properly, we should be cloning the op and 439 // deleting the previous op. This needs more investigation. 440 rewriter.startRootUpdate(op); 441 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 442 if (!promotedOp) { 443 rewriter.cancelRootUpdate(op); 444 return op->emitError("subview promotion failed"); 445 } 446 rewriter.finalizeRootUpdate(op); 447 marker.replaceLinalgTransformationFilter(rewriter, op); 448 return success(); 449 } 450 451 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 452 StringRef opName, MLIRContext *context, LinalgTransformationFilter marker, 453 PatternBenefit benefit) 454 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 455 456 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 457 Operation *op, PatternRewriter &rewriter) const { 458 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 459 if (!linalgOp) 460 return failure(); 461 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 462 return failure(); 463 if (failed(vectorizeLinalgOpPrecondition(op))) 464 return failure(); 465 vectorizeLinalgOp(rewriter, op); 466 rewriter.eraseOp(op); 467 return success(); 468 } 469 470 LogicalResult mlir::linalg::applyStagedPatterns( 471 Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns, 472 const FrozenRewritePatternList &stage2Patterns, 473 function_ref<LogicalResult(Operation *)> stage3Lambda) { 474 unsigned iteration = 0; 475 (void)iteration; 476 for (const auto &patterns : stage1Patterns) { 477 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 478 << *op); 479 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 480 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 481 return failure(); 482 } 483 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 484 << *op); 485 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 486 LLVM_DEBUG(DBGS() << "Underlying 2nd 
stage rewrite did not converge"); 487 return failure(); 488 } 489 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 490 << *op); 491 if (stage3Lambda) { 492 if (failed(stage3Lambda(op))) 493 return failure(); 494 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 495 << *op); 496 } 497 } 498 return success(); 499 } 500 501 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and 502 /// `ubVal` to `dims` and `stepVal` to `symbols`. 503 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`) 504 /// with positions matching the newly appended values. Substitute occurrences of 505 /// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression 506 /// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether 507 /// the induction variable is used with a positive or negative coefficient. 508 static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr, 509 Value lbVal, Value ubVal, Value stepVal, 510 SmallVectorImpl<Value> &dims, 511 SmallVectorImpl<Value> &symbols) { 512 MLIRContext *ctx = lbVal.getContext(); 513 AffineExpr lb = getAffineDimExpr(dims.size(), ctx); 514 dims.push_back(lbVal); 515 AffineExpr ub = getAffineDimExpr(dims.size(), ctx); 516 dims.push_back(ubVal); 517 AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx); 518 symbols.push_back(stepVal); 519 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n"); 520 AffineExpr ee = substWithMin(expr, dimExpr, lb, 521 lb + step * ((ub - 1) - lb).floorDiv(step)); 522 LLVM_DEBUG(DBGS() << "After: " << expr << "\n"); 523 return ee; 524 } 525 526 /// Traverse the `dims` and substitute known min or max expressions in place of 527 /// induction variables in `exprs`. 
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    // Iterate to a fixed point: each substitution appends new lb/ub dims and
    // a step symbol, which may themselves be loop induction variables.
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        // Null unless `dim` is recognized as an scf.for / scf.parallel
        // induction variable below.
        AffineExpr substitutedExpr;
        if (auto forOp = scf::getForInductionVarOwner(dim))
          substitutedExpr = substituteLoopInExpr(
              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
              forOp.step(), dims, symbols);

        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
               ++idx)
            substitutedExpr = substituteLoopInExpr(
                expr, dimExpr, parallelForOp.lowerBound()[idx],
                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
                dims, symbols);

        if (!substitutedExpr)
          continue;

        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Cleanup and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    // This local `map` deliberately shadows the parameter: it is rebuilt from
    // the current exprs/dims/symbols purely for composition/simplification.
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results.
    // NOTE(review): `exprs` is reassigned while being range-iterated; the
    // simplified map has the same number of results so no reallocation occurs,
    // but this relies on that size invariant — confirm before changing.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

/// Canonicalize an affine.min whose operands are scf loop induction variables:
/// substitute each IV by its lb/ub-derived min/max expression, then look for a
/// single result expression that is provably the minimum of all others.
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  // NOTE(review): candidates `e` are taken from the *original* map while the
  // differences use the substituted `map`'s results — confirm the dim/symbol
  // numbering stays compatible across `substitute`.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true unless the expression is a constant >= 0; i.e. treats
    // non-constant differences as "not provably non-negative".
    auto isNonPositive = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check everything is statically known to be
    // positive.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNonPositive))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      // Symbolic min: materialize it through an affine.apply over the
      // canonicalized operands.
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands = dims;
      resultOperands.append(symbols.begin(), symbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}