//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
                                         Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
                                           Operation *op) const {
  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker and matchDisjunction is empty: matches anything.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no marker but a marker was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any marker from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit marker from the disjunction.
  for (auto marker : matchDisjunction)
    if (attr.getValue() == marker)
      return success();

  // 4. Has a marker but it matches nothing in the disjunction.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any marker from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

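// For example (identifiers are illustrative), a marker constructed with
// matchDisjunction = {"tile-1"} and replacement = "tile-2" restricts a
// pattern to ops whose __internal_linalg_transform__ attribute is "tile-1"
// and re-tags the transformed op with "tile-2", so successive stages of a
// transformation pipeline can hand ops off to one another.
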
void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
                                                     Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created SimplePadOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.getType().cast<RankedTensorType>();
  // Already statically shaped, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A static bounding index must exist for every size.
    // For now, return an error if we cannot find one.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "no constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = rewriter.create<linalg::SimplePadOp>(opToPad->getLoc(),
                                                staticTensorType, operand, pad);
  return success();
}

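// Schematically (illustrative IR, not exact syntax), when %sz is known to be
// bounded above by the constant 4, the helper above turns the operand
//   %st = subtensor %t[%iv] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
// into
//   %pad = linalg.simple_pad %st pad %cst : tensor<?xf32> to tensor<4xf32>
// which the caller then uses in place of %st.
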
// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands.
// `res.op` is updated to the cloned static form of the op on success.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it in place.
  SmallVector<Value> operands = opToPad.getShapedOperands();
  for (Value &v : operands) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
                                                     options, paddedOperand)))
      return failure();
    // Update `v` only if we indeed got a padded operand.
    v = paddedOperand ? paddedOperand : v;
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  operands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, operands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because it uses the dims of the original
  // results. This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  llvm::SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
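/// Filters ops with `marker`, tiles them according to `options` and, when the
/// op has tensor semantics and `options.paddingValueComputationFunction` is
/// set, additionally pads the tiled operands to static bounding boxes via
/// `rewriteAsPaddedOp` above.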
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  bool succeeded = true;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    if (!succeeded)
      return;
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the marker on the tiled op and, if padding occurred, on the
    // tiled-and-padded op as well.
    marker.replaceLinalgMarker(rewriter, tiledOp);
    if (tiledOp != res->op)
      marker.replaceLinalgMarker(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Clear `succeeded` so the RAII guard does not propagate the
    // TiledLinalgOp to `result`.
    succeeded = false;
    return failure();
  }

  // Do not replace `linalgOp` here; derived patterns do this as they see fit,
  // from the resulting TiledLinalgOp.
  return success();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgMarker marker, LinalgMarker fusedOpMarker,
    LinalgMarker originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), marker(marker),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

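// Tile a Linalg op and fuse it with the producers selected by
// `fusionOptions.indicesToFuse`, then tile any loops that were not fused.
// Markers are updated on the tiled root (`marker`), on the fused producers
// (`fusedOpMarker`), and on the original ops that remain behind
// (`originalOpMarker`).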
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (!linalgOp.hasBufferSemantics())
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loops only if there is a non-zero tile size left.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.eraseOp(tiledAndFusedOps->op);
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }

  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers)
    fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back())
    originalOpMarker.replaceLinalgMarker(rewriter,
                                         origProducerOp.getOperation());
  rewriter.updateRootInPlace(
      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
  return success();
}

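// A minimal usage sketch (identifiers and op choice are illustrative): tile a
// matmul by {8, 16} while fusing the producer of its second operand, using
// the derived LinalgTileAndFusePattern declared in Transforms.h:
//   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
//       ctx, dependenceGraph, LinalgTilingOptions().setTileSizes({8, 16}),
//       LinalgFusionOptions().setIndicesToFuse({1}), ...);
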
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(op, [&]() {
    interchange(linalgOp, interchangeVector);
    // Set the new marker if specified.
    marker.replaceLinalgMarker(rewriter, op);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(marker.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  marker.replaceLinalgMarker(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();
  vectorizeLinalgOp(rewriter, op);
  rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
    const FrozenRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

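// A typical staging (illustrative): each stage-1 list performs one structural
// transformation (e.g. tile, then promote), the stage-2 list runs enabling
// cleanups after each of them (e.g. the AffineMinSCFCanonicalizationPattern
// below), and the optional stage-3 callback performs non-pattern rewrites
// such as hoisting on the enclosing op.
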
/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
/// `ubVal` to `dims` and `stepVal` to `symbols`.
/// Create new AffineDimExprs (`%lb` and `%ub`) and a new AffineSymbolExpr
/// (`%step`) with positions matching the newly appended values. Substitute
/// occurrences of `dimExpr` by either the min expression (i.e. `%lb`) or the
/// max expression (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`),
/// depending on whether the induction variable is used with a positive or
/// negative coefficient.
static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
                                       Value lbVal, Value ubVal, Value stepVal,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = lbVal.getContext();
  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(lbVal);
  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(ubVal);
  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
  symbols.push_back(stepVal);
  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
  AffineExpr ee = substWithMin(expr, dimExpr, lb,
                               lb + step * ((ub - 1) - lb).floorDiv(step));
  LLVM_DEBUG(DBGS() << "After: " << ee << "\n");
  return ee;
}

/// Traverse `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        AffineExpr substitutedExpr;
        if (auto forOp = scf::getForInductionVarOwner(dim))
          substitutedExpr = substituteLoopInExpr(
              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
              forOp.step(), dims, symbols);

        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
               ++idx)
            substitutedExpr = substituteLoopInExpr(
                expr, dimExpr, parallelForOp.lowerBound()[idx],
                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
                dims, symbols);

        if (!substitutedExpr)
          continue;

        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the result.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");

    // Pull in affine.apply operations and compose them fully into the result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

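/// Canonicalize an affine.min whose operands involve scf loop induction
/// variables with statically known bounds. For example (schematic IR), in
///   scf.for %iv = %c0 to %c12 step %c4 {
///     %m = affine.min affine_map<(d0) -> (4, 12 - d0)>(%iv)
/// substituting %iv by its largest value, 0 + 4 * floordiv(12 - 1 - 0, 4) = 8,
/// shows that 12 - d0 >= 4 for all iterations, so %m folds to the constant 4.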
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true when `e` is known to be negative or cannot statically be
    // proven non-negative.
    auto isNegativeOrUnknown = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that every result is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "Simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNegativeOrUnknown))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands = dims;
      resultOperands.append(symbols.begin(), symbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}
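
// A minimal end-to-end sketch (illustrative identifiers, op choice, and
// staging): tile matmuls tagged "START" with the derived LinalgTilingPattern
// from Transforms.h, then clean up the affine.min ops created by tiling.
//   OwningRewritePatternList stage1, stage2;
//   stage1.insert<LinalgTilingPattern<MatmulOp>>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 16, 4}),
//       LinalgMarker(Identifier::get("START", ctx),
//                    Identifier::get("DONE", ctx)));
//   stage2.insert<AffineMinSCFCanonicalizationPattern>(ctx);
//   FrozenRewritePatternList frozenStage1(std::move(stage1));
//   (void)applyStagedPatterns(funcOp, frozenStage1, std::move(stage2));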