1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 21 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/ADT/ScopeExit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <type_traits> 32 33 #define DEBUG_TYPE "linalg-transforms" 34 35 using namespace mlir; 36 using namespace mlir::edsc; 37 using namespace mlir::edsc::intrinsics; 38 using namespace mlir::linalg; 39 40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 41 42 //===----------------------------------------------------------------------===// 43 // Transformations exposed as rewrite patterns. 44 //===----------------------------------------------------------------------===// 45 // Marker used as attribute name in generated Linalg rewriting transformations. 
46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 47 "__internal_linalg_transform__"; 48 49 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 50 ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement) 51 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 52 replacement(replacement) {} 53 54 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 55 FilterFunction f, ArrayRef<Identifier> matchDisjunction, 56 Optional<Identifier> replacement) 57 : filters(), 58 matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 59 replacement(replacement) { 60 if (f) 61 filters.push_back(f); 62 } 63 64 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify( 65 PatternRewriter &rewriter, Operation *op) const { 66 if (llvm::any_of(filters, 67 [&](const FilterFunction &f) { return failed(f(op)); })) 68 return failure(); 69 70 auto attr = op->template getAttrOfType<StringAttr>( 71 LinalgTransforms::kLinalgTransformMarker); 72 73 if (!attr) { 74 // 1. Has no filter case and matchDisjunction is empty. 75 if (matchDisjunction.empty()) 76 return success(); 77 78 // 2. Has no filter but was expecting a filter. 79 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 80 diag << " does not have any filter from list: "; 81 interleaveComma(matchDisjunction, diag); 82 }); 83 } 84 85 // 4. Match explicit filter. 86 for (auto filter : matchDisjunction) 87 if (attr.getValue() == filter) 88 return success(); 89 90 // 5. Fail to match. 
91 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 92 diag << " does not have any filter from list: "; 93 interleaveComma(matchDisjunction, diag); 94 }); 95 } 96 97 void mlir::linalg::LinalgTransformationFilter:: 98 replaceLinalgTransformationFilter(PatternRewriter &rewriter, 99 Operation *op) const { 100 if (replacement.hasValue()) 101 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 102 rewriter.getStringAttr(replacement.getValue())); 103 else 104 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 105 rewriter.getContext())); 106 } 107 108 LinalgTilingOptions & 109 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 110 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 111 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 112 OpBuilder::InsertionGuard guard(b); 113 b.setInsertionPointToStart( 114 &op->getParentOfType<FuncOp>().getBody().front()); 115 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 116 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 117 return v; 118 })); 119 }; 120 return *this; 121 } 122 123 /// Try to compute a static bounding box for `operand` 124 /// Return success if either: 125 /// 1. The operand is already statically shaped, `result` is left unchanged. 126 /// 2. The operand is (partially) dynamic, `result` is the result of a freshly 127 /// created PadTensorOp. 128 /// Return failure if the operand cannot be padded to a static shape. 129 static LogicalResult padOperandToSmallestStaticBoundingBox( 130 PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand, 131 const LinalgTilingOptions &options, Value &result) { 132 auto tensorType = operand.get().getType().cast<RankedTensorType>(); 133 // Already static shape, no need to pad. 134 if (tensorType.hasStaticShape()) 135 return success(); 136 auto subtensor = operand.get().getDefiningOp<SubTensorOp>(); 137 // Not a subtensor, cannot construct a static bounding box. 
138 if (!subtensor) 139 return failure(); 140 SmallVector<int64_t> staticSizes; 141 staticSizes.reserve(tensorType.getRank()); 142 auto shapedOp = 143 cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation()); 144 for (auto size : shapedOp.getMixedSizes()) { 145 auto indexAttr = size.is<Attribute>() 146 ? size.get<Attribute>().dyn_cast<IntegerAttr>() 147 : linalg::getSmallestBoundingIndex(size.get<Value>()); 148 // SmallestBoundingIndex must exist for all sizes. 149 // For now return an error if we can't find it. 150 if (!indexAttr) 151 return rewriter.notifyMatchFailure( 152 opToPad, "No constant bounding box can be found for padding"); 153 staticSizes.push_back(indexAttr.getInt()); 154 } 155 Value pad = options.paddingValueComputationFunction(rewriter, operand); 156 auto staticTensorType = 157 RankedTensorType::get(staticSizes, tensorType.getElementType()); 158 result = linalg::PadTensorOp::createPadHighOp( 159 staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter); 160 return success(); 161 } 162 163 // Try to create a static bounding box around each operand of `res.op`. 164 // If successful, `res.op` is rewritten in static form with padded operands. 165 // `res.op` is updated to the cloned static form of the op on success. 166 static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, 167 TiledLinalgOp &res, 168 const LinalgTilingOptions &options) { 169 LinalgOp opToPad = res.op; 170 Location loc = opToPad->getLoc(); 171 172 // If the op is fully static, it does not need padding. 173 // TODO: there are cases where we may still want to pad to larger sizes. 174 if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) { 175 return v.getType().cast<RankedTensorType>().hasStaticShape(); 176 })) 177 return success(); 178 179 OpBuilder::InsertionGuard g(rewriter); 180 // Set IP after op because we also take the dims of the original output. 181 rewriter.setInsertionPointAfter(opToPad); 182 // Make a copy of the shaped operands and update it. 
183 SmallVector<Value> newOperands; 184 newOperands.reserve(opToPad.getNumShapedOperands()); 185 for (OpOperand &operand : opToPad.getShapedOpOperands()) { 186 Value paddedOperand; 187 // If padding was requested but the shape cannot be bounded statically then 188 // the pattern fails to apply. 189 if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, operand, 190 options, paddedOperand))) { 191 return failure(); 192 } 193 newOperands.push_back(paddedOperand ? paddedOperand : operand.get()); 194 } 195 196 // Clone `opToPad` to operate on the statically padded shapes. 197 auto resultTensorTypes = 198 ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); 199 ValueRange otherOperands = opToPad.getAssumedNonShapedOperands(); 200 newOperands.append(otherOperands.begin(), otherOperands.end()); 201 linalg::LinalgOp paddedOp = 202 opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); 203 204 // Recover the subtensor out of the new static results. This keeps the 205 // original linalg op around because it uses the dims of the original results. 206 // This later folds away. 
207 SmallVector<Value> paddedSubviewResults; 208 paddedSubviewResults.reserve(opToPad->getNumResults()); 209 llvm::SetVector<Operation *> newUsersOfOpToPad; 210 for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) { 211 auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank(); 212 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); 213 auto sizes = llvm::to_vector<4>(llvm::map_range( 214 llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult { 215 auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d); 216 newUsersOfOpToPad.insert(dimOp); 217 return dimOp.getResult(); 218 })); 219 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); 220 paddedSubviewResults.push_back(rewriter.create<SubTensorOp>( 221 loc, std::get<1>(it), offsets, sizes, strides)); 222 } 223 // Replace the transient `opToPad` locally, except for uses that we just 224 // created for the purpose of extracting the dims. 225 rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) { 226 return !newUsersOfOpToPad.contains(opOp.getOwner()); 227 }); 228 229 res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults}; 230 return success(); 231 } 232 233 /// Linalg base tiling pattern. 
/// Tiling pattern rooted at ops named `opName`.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

/// Tiling pattern that is tried on any op type (MatchAnyOpTypeTag); non-Linalg
/// ops are rejected in matchAndRewriteBase.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Shared implementation for tiling patterns.
///
/// The steps are:
/// 1. Tile `op` with `options`.
/// 2. If a padding value function is set and the op has tensor semantics,
///    rewrite the tiled op as a statically padded op.
///
/// On success, `result` receives the TiledLinalgOp. The filter's replacement
/// marker is applied to the tiled op and, if different, to the padded op.
/// The original `op` is NOT replaced here; derived patterns decide how to use
/// `result`.
/// NOTE(review): tiling mutates the IR before padding can still fail, so a
/// failure() return may leave newly created ops behind.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Setup RAII guard to return properly.
  // The guard runs on every return path below. It publishes `*res` and
  // updates markers only when `succeeded` is still true.
  bool succeeded = true;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    if (!succeeded)
      return;
    // Return relevant information to derived pattern.
    result = *res;
    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  // On success, rewriteAsPaddedOp updates `res->op` to the padded clone.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Set so RAII guard does not propagate TiledLinalgOp to `result`.
    succeeded = false;
    return failure();
  }

  // Do not perform replacement of `linalgOp`, let the derived patterns
  // do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

/// Values that replace the untiled op: the results of the outermost generated
/// loop, or the tiled op's own results when no loop was generated.
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

/// Same as getTiledOpResult, for the tile-and-fuse result: the outermost fused
/// loop's results, or the op's own results when there are no fused loops.
static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

/// Tile-and-fuse pattern rooted at ops named `opName`.
/// The three marker filters play separate roles:
/// - `filter`: matches the root op and marks the tiled root.
/// - `fusedOpMarker`: marks the fused producer clones.
/// - `originalOpMarker`: marks the original producers and the original root.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

/// The rewrite proceeds in four steps:
/// 1. Tile `op` and fuse into it the Linalg producers that feed operands
///    listed in `fusionOptions.indicesToFuse`.
/// 2. Tile any remaining (unfused) loop dimensions that have non-zero tile
///    sizes.
/// 3. Redirect uses of `op` to the fused loop nest's results.
/// 4. Update the markers on all involved ops.
/// The original `op` and its producers are kept in the IR (only re-marked).
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect `op` plus every Linalg producer reached through a dependence into
  // one of the operand positions selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Order the producers as they appear in the block before `op`, then append
  // `op` itself as the last element.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  // Dimensions already tiled by fusion get a zero tile size (i.e. they are
  // not tiled again); the others keep their original tile size.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  // A size that is not a ConstantIndexOp counts as possibly non-zero.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  // NOTE(review): this RAUW goes around the rewriter, so the pattern driver is
  // not notified about the replaced uses.
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  // All entries except the last (which is `op` itself) are original producers.
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg base interchange pattern.
402 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 403 StringRef opName, MLIRContext *context, 404 ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter, 405 PatternBenefit benefit) 406 : RewritePattern(opName, benefit, context, {}), filter(filter), 407 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 408 409 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 410 Operation *op, PatternRewriter &rewriter) const { 411 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 412 // TODO: remove hasIndexSemantics check once index ops are supported. 413 if (!linalgOp || linalgOp.hasIndexSemantics()) 414 return failure(); 415 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 416 return failure(); 417 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 418 return failure(); 419 420 // TODO: figure out how this interplays with named ops. In particular this 421 // should break the named op property. 422 rewriter.updateRootInPlace(op, [&]() { 423 interchange(linalgOp, interchangeVector); 424 // New filter if specified. 425 filter.replaceLinalgTransformationFilter(rewriter, op); 426 }); 427 return success(); 428 } 429 430 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 431 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 432 LinalgTransformationFilter filter, PatternBenefit benefit) 433 : RewritePattern(opName, benefit, context, {}), filter(filter), 434 options(options) {} 435 436 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 437 Operation *op, PatternRewriter &rewriter) const { 438 if (failed(filter.checkAndNotify(rewriter, op))) 439 return failure(); 440 if (failed(promoteSubviewsPrecondition(op, options))) 441 return failure(); 442 443 // TODO: We cannot use root update here. 
This pattern is creating other ops, 444 // so if the promotion fails, those need to be cleaned up, which doesnt seem 445 // to be happening here. So to fail properly, we should be cloning the op and 446 // deleting the previous op. This needs more investigation. 447 rewriter.startRootUpdate(op); 448 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 449 if (!promotedOp) { 450 rewriter.cancelRootUpdate(op); 451 return op->emitError("subview promotion failed"); 452 } 453 rewriter.finalizeRootUpdate(op); 454 filter.replaceLinalgTransformationFilter(rewriter, op); 455 return success(); 456 } 457 458 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 459 MLIRContext *context, LinalgTransformationFilter filter, 460 PatternBenefit benefit) 461 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {} 462 463 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 464 StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, 465 PatternBenefit benefit) 466 : RewritePattern(opName, benefit, context, {}), filter(filter) {} 467 468 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 469 Operation *op, PatternRewriter &rewriter) const { 470 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 471 // TODO: remove hasIndexSemantics check once index ops are supported. 
472 if (!linalgOp || linalgOp.hasIndexSemantics()) 473 return failure(); 474 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 475 return failure(); 476 SmallVector<Value> newResults; 477 if (failed(vectorizeLinalgOp(rewriter, op, newResults))) 478 return failure(); 479 if (!newResults.empty()) 480 rewriter.replaceOp(op, newResults); 481 else 482 rewriter.eraseOp(op); 483 return success(); 484 } 485 486 LogicalResult mlir::linalg::applyStagedPatterns( 487 Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns, 488 const FrozenRewritePatternSet &stage2Patterns, 489 function_ref<LogicalResult(Operation *)> stage3Lambda) { 490 unsigned iteration = 0; 491 (void)iteration; 492 for (const auto &patterns : stage1Patterns) { 493 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 494 << *op); 495 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 496 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 497 return failure(); 498 } 499 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 500 << *op); 501 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 502 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 503 return failure(); 504 } 505 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 506 << *op); 507 if (stage3Lambda) { 508 if (failed(stage3Lambda(op))) 509 return failure(); 510 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 511 << *op); 512 } 513 } 514 return success(); 515 } 516 517 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and 518 /// `ubVal` to `dims` and `stepVal` to `symbols`. 519 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`) 520 /// with positions matching the newly appended values. Substitute occurrences of 521 /// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression 522 /// (i.e. 
`%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether 523 /// the induction variable is used with a positive or negative coefficient. 524 static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr, 525 Value lbVal, Value ubVal, Value stepVal, 526 SmallVectorImpl<Value> &dims, 527 SmallVectorImpl<Value> &symbols) { 528 MLIRContext *ctx = lbVal.getContext(); 529 AffineExpr lb = getAffineDimExpr(dims.size(), ctx); 530 dims.push_back(lbVal); 531 AffineExpr ub = getAffineDimExpr(dims.size(), ctx); 532 dims.push_back(ubVal); 533 AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx); 534 symbols.push_back(stepVal); 535 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n"); 536 AffineExpr ee = substWithMin(expr, dimExpr, lb, 537 lb + step * ((ub - 1) - lb).floorDiv(step)); 538 LLVM_DEBUG(DBGS() << "After: " << expr << "\n"); 539 return ee; 540 } 541 542 /// Traverse the `dims` and substitute known min or max expressions in place of 543 /// induction variables in `exprs`. 
544 static AffineMap substitute( 545 AffineMap map, SmallVectorImpl<Value> &dims, 546 SmallVectorImpl<Value> &symbols, 547 llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) { 548 auto exprs = llvm::to_vector<4>(map.getResults()); 549 for (AffineExpr &expr : exprs) { 550 bool substituted = true; 551 while (substituted) { 552 substituted = false; 553 for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) { 554 Value dim = dims[dimIdx]; 555 AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext()); 556 LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n"); 557 AffineExpr substitutedExpr; 558 if (auto forOp = scf::getForInductionVarOwner(dim)) 559 if (!substituteOperation || substituteOperation(forOp)) 560 substitutedExpr = substituteLoopInExpr( 561 expr, dimExpr, forOp.lowerBound(), forOp.upperBound(), 562 forOp.step(), dims, symbols); 563 564 if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) 565 if (!substituteOperation || substituteOperation(parallelForOp)) 566 for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; 567 ++idx) 568 substitutedExpr = substituteLoopInExpr( 569 expr, dimExpr, parallelForOp.lowerBound()[idx], 570 parallelForOp.upperBound()[idx], parallelForOp.step()[idx], 571 dims, symbols); 572 573 if (!substitutedExpr) 574 continue; 575 576 substituted = (substitutedExpr != expr); 577 expr = substitutedExpr; 578 } 579 } 580 581 // Cleanup and simplify the results. 582 // This needs to happen outside of the loop iterating on dims.size() since 583 // it modifies dims. 
584 SmallVector<Value, 4> operands(dims.begin(), dims.end()); 585 operands.append(symbols.begin(), symbols.end()); 586 auto map = AffineMap::get(dims.size(), symbols.size(), exprs, 587 exprs.front().getContext()); 588 589 LLVM_DEBUG({ 590 DBGS() << "Map to simplify: " << map << "\n"; 591 DBGS() << "Operands:\n"; 592 for (Value v : operands) 593 DBGS() << v << "\n"; 594 }); 595 596 // Pull in affine.apply operations and compose them fully into the 597 // result. 598 fullyComposeAffineMapAndOperands(&map, &operands); 599 canonicalizeMapAndOperands(&map, &operands); 600 map = simplifyAffineMap(map); 601 // Assign the results. 602 exprs.assign(map.getResults().begin(), map.getResults().end()); 603 dims.assign(operands.begin(), operands.begin() + map.getNumDims()); 604 symbols.assign(operands.begin() + map.getNumDims(), operands.end()); 605 606 LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n"); 607 } 608 609 assert(!exprs.empty() && "Unexpected empty exprs"); 610 return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext()); 611 } 612 613 /// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop 614 /// induction variables by new expressions involving the lower or upper bound: 615 /// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is 616 /// replaced by the loop upper bound. 617 /// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is 618 /// replaced by the loop lower bound. 619 /// All loop induction variables are iteratively replaced, unless a 620 /// `substituteOperation` hook is passed to more finely determine which 621 /// operations are substituted. 622 /// This is used as an intermediate step in computing bounding boxes and 623 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have 624 /// positive values (positive orthant assumptions). 625 /// Return a new AffineMap, dims and symbols that have been canonicalized and 626 /// simplified. 
627 AffineMapAndOperands mlir::linalg::substituteMin( 628 AffineMinOp affineMinOp, 629 llvm::function_ref<bool(Operation *)> substituteOperation) { 630 AffineMapAndOperands res{affineMinOp.getAffineMap(), 631 SmallVector<Value>(affineMinOp.getDimOperands()), 632 SmallVector<Value>(affineMinOp.getSymbolOperands())}; 633 res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols, 634 substituteOperation); 635 return res; 636 } 637 638 LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite( 639 AffineMinOp minOp, PatternRewriter &rewriter) const { 640 LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation() 641 << "\n"); 642 643 auto affineMapAndOperands = substituteMin(minOp); 644 AffineMap map = affineMapAndOperands.map; 645 646 LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n"); 647 648 // Check whether any of the expressions, when subtracted from all other 649 // expressions, produces only >= 0 constants. If so, it is the min. 650 for (auto e : minOp.getAffineMap().getResults()) { 651 LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n"); 652 if (!e.isSymbolicOrConstant()) 653 continue; 654 655 auto isNonPositive = [](AffineExpr e) { 656 if (auto cst = e.dyn_cast<AffineConstantExpr>()) 657 return cst.getValue() < 0; 658 return true; 659 }; 660 661 // Build the subMap and check everything is statically known to be 662 // positive. 663 SmallVector<AffineExpr, 4> subExprs; 664 subExprs.reserve(map.getNumResults()); 665 for (auto ee : map.getResults()) 666 subExprs.push_back(ee - e); 667 MLIRContext *ctx = minOp.getContext(); 668 AffineMap subMap = simplifyAffineMap( 669 AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx)); 670 LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n"); 671 if (llvm::any_of(subMap.getResults(), isNonPositive)) 672 continue; 673 674 // Static min found. 
675 if (auto cst = e.dyn_cast<AffineConstantExpr>()) { 676 rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue()); 677 } else { 678 auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx); 679 SmallVector<Value> resultOperands = affineMapAndOperands.dims; 680 llvm::append_range(resultOperands, affineMapAndOperands.symbols); 681 canonicalizeMapAndOperands(&resultMap, &resultOperands); 682 resultMap = simplifyAffineMap(resultMap); 683 rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap, 684 resultOperands); 685 } 686 return success(); 687 } 688 689 return failure(); 690 } 691