//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : LinalgTransformationFilter([](Operation *) { return success(); },
                                 matchDisjunction, replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filter(f),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (filter && failed(filter(op)))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker and matchDisjunction is empty: match.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no marker but was expecting a marker: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any marker from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit marker from the list.
  for (auto marker : matchDisjunction)
    if (attr.getValue() == marker)
      return success();

  // 4. The marker does not match any entry in the list: fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any marker from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.getType().cast<RankedTensorType>();
  // Already statically shaped, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // The smallest bounding index must exist for all sizes; for now, return an
    // error if it cannot be found.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = linalg::PadTensorOp::createPadHighOp(staticTensorType, operand, pad,
                                                opToPad->getLoc(), rewriter);
  return success();
}

/// Try to create a static bounding box around each operand of `res.op`.
/// If successful, `res.op` is rewritten in static form with padded operands
/// and is updated to the cloned static op.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> operands = opToPad.getShapedOperands();
  for (Value &v : operands) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
                                                     options, paddedOperand))) {
      return failure();
    }
    // Update v if we indeed got a padded operand.
    v = paddedOperand ? paddedOperand : v;
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  operands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, operands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because it uses the dims of the original
  // results. This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  llvm::SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
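///
/// Illustrative usage sketch only (not part of this file): assuming a typed
/// derived pattern such as `LinalgTilingPattern<linalg::MatmulOp>` from the
/// accompanying Transforms.h header, plus a surrounding `MLIRContext *context`
/// and `FuncOp funcOp`, a client would typically configure and apply it along
/// these lines:
///
///   LinalgTilingOptions tilingOptions;
///   tilingOptions.setTileSizes({8, 16, 4});
///   OwningRewritePatternList patterns;
///   patterns.insert<LinalgTilingPattern<linalg::MatmulOp>>(
///       context, tilingOptions,
///       LinalgTransformationFilter(Identifier::get("tile", context)));
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));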
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    LinalgTilingOptions options, LinalgTransformationFilter marker,
    PatternBenefit benefit)
    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  bool succeeded = true;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    if (!succeeded)
      return;
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the marker on both the tiled op and the tiled-and-padded op, if
    // necessary.
    marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      marker.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Clear `succeeded` so the RAII guard does not propagate the TiledLinalgOp
    // to `result`.
    succeeded = false;
    return failure();
  }

  // Do not perform the replacement of `linalgOp` here; let the derived
  // patterns do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), marker(marker),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (!linalgOp.hasBufferSemantics())
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op is always
    // an OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loops only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.eraseOp(tiledAndFusedOps->op);
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }

  marker.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg base interchange pattern.
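///
/// For example (illustrative only), applying this pattern with
/// `interchangeVector = {1, 0}` to a two-dimensional `linalg.generic` swaps
/// its two loops: the iterator types and indexing maps of the op are permuted
/// in place while its semantics are preserved.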
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(op, [&]() {
    interchange(linalgOp, interchangeVector);
    // Set the new marker if specified.
    marker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(marker.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern is creating other
  // ops, so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
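  // Attempt the promotion under a manual root update: `promoteSubViews`
  // modifies `op` in place, so the update is finalized on success and
  // cancelled on failure below.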
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  marker.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();
  vectorizeLinalgOp(rewriter, op);
  rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
    const FrozenRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
/// `ubVal` to `dims` and `stepVal` to `symbols`.
/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
/// with positions matching the newly appended values. Substitute occurrences
/// of `dimExpr` by either the min expression (i.e. `%lb`) or the max
/// expression (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`), depending
/// on whether the induction variable is used with a positive or negative
/// coefficient.
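///
/// Worked example (illustrative only): for a loop
///   `scf.for %iv = %lb to %ub step %step`
/// and the expression `s0 - d0`, where `d0` stands for `%iv`, the coefficient
/// of `%iv` is negative, so `d0` is replaced by the max expression. The result
/// is `s0 - (%lb + %step * floordiv(%ub - 1 - %lb, %step))`, i.e. the smallest
/// value the original expression can take over the loop range (assuming the
/// loop executes at least once).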
static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
                                       Value lbVal, Value ubVal, Value stepVal,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = lbVal.getContext();
  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(lbVal);
  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(ubVal);
  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
  symbols.push_back(stepVal);
  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
  AffineExpr ee = substWithMin(expr, dimExpr, lb,
                               lb + step * ((ub - 1) - lb).floorDiv(step));
  LLVM_DEBUG(DBGS() << "After: " << ee << "\n");
  return ee;
}

/// Traverse `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        AffineExpr substitutedExpr;
        if (auto forOp = scf::getForInductionVarOwner(dim))
          substitutedExpr = substituteLoopInExpr(
              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
              forOp.step(), dims, symbols);

        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
               ++idx)
            substitutedExpr = substituteLoopInExpr(
                expr, dimExpr, parallelForOp.lowerBound()[idx],
                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
                dims, symbols);

        if (!substitutedExpr)
          continue;

        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Cleanup and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");

    // Pull in affine.apply operations and compose them fully into the result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);

    // Assign the results.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all the other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Return true if `e` is a negative constant or is not statically known to
    // be non-negative.
    auto isNegativeOrUnknown = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that every result is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "Simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNegativeOrUnknown))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands = dims;
      resultOperands.append(symbols.begin(), symbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}