//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
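
// Illustrative usage sketch (not exercised in this file; the marker strings
// are hypothetical): a filter that only fires on ops labeled "tile" and
// re-labels the rewritten op "vectorize" so a later stage can match it.
//
//   auto filter = LinalgTransformationFilter(
//       Identifier::get("tile", ctx), Identifier::get("vectorize", ctx));
//   // checkAndNotify(rewriter, op) succeeds only if `op` carries
//   // `__internal_linalg_transform__ = "tile"`.
//   // replaceLinalgTransformationFilter(rewriter, newOp) then sets
//   // `__internal_linalg_transform__ = "vectorize"` on `newOp`.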

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
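
// Illustrative usage sketch: request constant 2-D tile sizes of 8x16. When
// the pattern runs, the lambda above materializes the constants at the top
// of the enclosing function.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});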

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
/// 1. The operand is already statically shaped, `result` is left unchanged.
/// 2. The operand is (partially) dynamic, `result` is the result of a freshly
///    created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.get().getType().cast<RankedTensorType>();
  // Already statically shaped, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for all sizes; for now, fail to
    // match if we cannot find one.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "no constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, operand);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
  return success();
}

/// Try to create a static bounding box around each operand of `res.op`.
/// If successful, `res.op` is rewritten in static form with padded operands.
/// `res.op` is updated to the cloned static form of the op on success.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumShapedOperands());
  for (OpOperand &operand : opToPad.getShapedOpOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // then the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, operand, options, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : operand.get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because it uses the dims of the original
  // results. This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}
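
// Illustrative effect of the rewrite above (shapes hypothetical): given a
// tiled op consuming a `subtensor` of size `%s x 16` where `%s` is known to
// be bounded by 8, the operand is padded to a static `8x16` tensor, the op
// is cloned to run on the static shape, and a `subtensor` of the original
// dynamic sizes is carved back out of each padded result.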

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op. On failure,
  // return failure; callers must disregard `result` when this fails (the
  // scope-exit guard above still runs and writes to it).
  if (failed(rewriteAsPaddedOp(rewriter, *res, options)))
    return failure();

  // Do not perform replacement of `linalgOp`; let the derived patterns
  // do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}
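
// Illustrative sketch of how a derived pattern is expected to drive
// `matchAndRewriteBase` (hypothetical subclass; the actual derived patterns
// are declared in Transforms.h):
//
//   LogicalResult matchAndRewrite(Operation *op,
//                                 PatternRewriter &rewriter) const override {
//     TiledLinalgOp tiledOp;
//     if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
//       return failure();
//     if (tiledOp.tensorResults.empty())
//       rewriter.eraseOp(op);
//     else
//       rewriter.replaceOp(op, tiledOp.tensorResults);
//     return success();
//   }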

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove the hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op is always
    // an OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the remaining loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
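
// Marker flow for the pattern above, in brief: the tiled-and-fused root op
// receives `filter`'s replacement marker, each fused producer receives
// `fusedOpMarker`, and the untouched original producers receive
// `originalOpMarker` so that no pattern in the set matches them again.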

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
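
// Illustrative effect: with interchangeVector = {1, 0} on a 2-D generic op,
// the iteration dimensions are permuted, i.e. the op's indexing maps and
// iterator types are rewritten under the (d0, d1) -> (d1, d0) permutation so
// that the loop that was second is emitted outermost.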

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern is creating other
  // ops, so if the promotion fails, those need to be cleaned up, which
  // doesn't seem to be happening here. So to fail properly, we should be
  // cloning the op and deleting the previous op. This needs more
  // investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
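
// Illustrative usage sketch (the pattern sets and the stage-3 hook are
// hypothetical): run one tiling stage, canonicalize after it, then run a
// custom callback.
//
//   SmallVector<FrozenRewritePatternSet, 2> stage1;
//   stage1.emplace_back(std::move(tilingPatterns)); // a RewritePatternSet
//   FrozenRewritePatternSet stage2(std::move(cleanupPatterns));
//   (void)applyStagedPatterns(funcOp, stage1, stage2,
//                             [](Operation *op) { return someHook(op); });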

/// Traverse `dims` and substitute known min or max expressions returned by
/// the lambda `getMinMaxExpr`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols,
                            GetMinMaxExprFn getMinMaxExpr) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        auto minMax = getMinMaxExpr(dim, dims, symbols);
        if (!minMax)
          continue;
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
        // Substitute occurrences of `dimExpr` by either the min expression or
        // the max expression depending on whether the value is used with a
        // positive or negative coefficient.
        AffineExpr substitutedExpr =
            substWithMin(expr, dimExpr, minMax->first, minMax->second);
        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    // Use a distinct name to avoid shadowing the `map` parameter, which is
    // still needed for the final `getContext()` below.
    auto simplifiedMap = AffineMap::get(dims.size(), symbols.size(), exprs,
                                        exprs.front().getContext());

    LLVM_DEBUG({
      DBGS() << "Map to simplify: " << simplifiedMap << "\n";
      DBGS() << "Operands:\n";
      for (Value v : operands)
        DBGS() << v << "\n";
    });

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&simplifiedMap, &operands);
    canonicalizeMapAndOperands(&simplifiedMap, &operands);
    simplifiedMap = simplifyAffineMap(simplifiedMap);
    // Assign the results.
    exprs.assign(simplifiedMap.getResults().begin(),
                 simplifiedMap.getResults().end());
    dims.assign(operands.begin(),
                operands.begin() + simplifiedMap.getNumDims());
    symbols.assign(operands.begin() + simplifiedMap.getNumDims(),
                   operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << simplifiedMap << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
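
// Worked example for the substitution above (the range is hypothetical): if
// `expr = d0 - d1` and `getMinMaxExpr` reports that the value behind d1
// ranges over [lb, ub], then d1 appears with a negative coefficient and
// `substWithMin` rewrites the expression to `d0 - ub`, a conservative lower
// bound; a positive occurrence of d1 would instead be replaced by `lb`.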

/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
/// dimensions with known range by new expressions involving the min or max
/// expression:
///   - If the AffineDimExpr mapped to a known value has a positive sign, it
///     is replaced by the min expression.
///   - If the AffineDimExpr mapped to a known value has a negative sign, it
///     is replaced by the max expression.
/// All known values are iteratively replaced.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
/// have positive values (positive orthant assumption).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
                            GetMinMaxExprFn getMinMaxExpr) {
  AffineMapAndOperands res{affineMinOp.getAffineMap(),
                           SmallVector<Value>(affineMinOp.getDimOperands()),
                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
                       getMinMaxExpr);
  return res;
}

LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
  AffineMap map = affineMapAndOperands.map;

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true when `e` is a negative constant or is not statically
    // known to be non-negative.
    auto isNegativeOrUnknown = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that every result is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "Simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNegativeOrUnknown))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}
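
// Illustrative end-to-end example for the pattern above (IR schematic, the
// range is hypothetical):
//   %m = affine.min affine_map<(d0) -> (8, d0)>(%i)
// If `getMinMaxFn` reports that %i lies in [8, 16], substitution turns the
// map results into {8, 8}; every candidate difference is then a >= 0
// constant, so the affine.min is replaced by `constant 8 : index`.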