//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
                                         Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
                                           Operation *op) const {
  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker and the match list is empty: anything matches.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no marker but one was expected: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any marker from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. The op has a marker: match if it appears in the disjunction.
  for (auto marker : matchDisjunction)
    if (attr.getValue() == marker)
      return success();

  // 4. The marker is not in the disjunction: fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any marker from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
                                                     Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}
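// Example (sketch): markers let a pass chain transformations by relabeling
// ops between stages. The identifiers below are illustrative, not part of
// this file:
//
//   // Only match ops labeled "tile"; relabel the result "promote" so a
//   // later pattern picks it up.
//   LinalgMarker marker({Identifier::get("tile", ctx)},
//                       Identifier::get("promote", ctx));
//
// An empty matchDisjunction matches only unmarked ops; passing llvm::None as
// the replacement removes the marker after the rewrite.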
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created SimplePadOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.getType().cast<RankedTensorType>();
  // Already statically shaped, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for every size; for now, return an
    // error if we cannot find one.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "no constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = rewriter.create<linalg::SimplePadOp>(opToPad->getLoc(),
                                                staticTensorType, operand, pad);
  return success();
}
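// Example (sketch): the IR below illustrates the intended rewrite; the
// textual form is approximate, not taken from a test. Given a subtensor
// whose dynamic size %s has a known upper bound of 8 (e.g. it comes from an
// affine.min of a tiled loop):
//
//   %st = subtensor %t[%i, 0] [%s, 4] [1, 1]
//       : tensor<?x4xf32> to tensor<?x4xf32>
//
// the operand is padded to its static bounding box:
//
//   %padded = linalg.simple_pad %st pad %cst
//       : tensor<?x4xf32> to tensor<8x4xf32>
//
// so that the consuming op can be cloned with fully static operand types.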
/// Try to create a static bounding box around each operand of `res.op`.
/// If successful, `res.op` is rewritten in static form with padded operands
/// and is updated to the cloned static op.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // its original outputs.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> operands = opToPad.getShapedOperands();
  for (Value &v : operands) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
                                                     options, paddedOperand)))
      return failure();
    // Update v only if we indeed got a padded operand.
    v = paddedOperand ? paddedOperand : v;
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  operands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, operands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because the subtensor sizes use the dims of its
  // original results; the op later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  llvm::SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<Value> offsets(rank, zero);
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> Value {
          auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp;
        }));
    SmallVector<Value> strides(rank, one);
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for the uses we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}
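// Example (sketch): a tile-and-pad configuration might look as follows. The
// derived pattern LinalgTilingPattern is declared in Transforms.h; the pad
// value callback and the marker name are illustrative:
//
//   auto constantPad = [](OpBuilder &b, Operation *op) -> Value {
//     return b.create<ConstantOp>(op->getLoc(), b.getF32FloatAttr(0.f));
//   };
//   LinalgTilingOptions opts = LinalgTilingOptions().setTileSizes({8, 16, 4});
//   opts.paddingValueComputationFunction = constantPad;
//   patterns.insert<LinalgTilingPattern<MatmulOp>>(
//       ctx, opts, LinalgMarker({}, Identifier::get("tiled_padded", ctx)));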
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  bool succeeded = true;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    if (!succeeded)
      return;
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the marker on the tiled op and, if padding created a new op,
    // on the tiled-and-padded op as well.
    marker.replaceLinalgMarker(rewriter, tiledOp);
    if (tiledOp != res->op)
      marker.replaceLinalgMarker(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Prevent the RAII guard from propagating the TiledLinalgOp to `result`.
    succeeded = false;
    return failure();
  }

  // Do not replace `linalgOp` here; derived patterns do this as they see fit,
  // from the resulting TiledLinalgOp.
  return success();
}
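// Example (sketch): a derived pattern is expected to call matchAndRewriteBase
// and perform the final replacement itself, along these lines (hypothetical
// derived class):
//
//   LogicalResult matchAndRewrite(Operation *op,
//                                 PatternRewriter &rewriter) const override {
//     TiledLinalgOp tiledOp;
//     if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
//       return failure();
//     if (tiledOp.tensorResults.empty())
//       rewriter.eraseOp(op);
//     else
//       rewriter.replaceOp(op, tiledOp.tensorResults);
//     return success();
//   }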
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgMarker marker, LinalgMarker fusedOpMarker,
    LinalgMarker originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), marker(marker),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (!linalgOp.hasBufferSemantics())
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op view is always an
    // OpOperand, so the operand number should be available. We could assert,
    // but continue gracefully if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Use a zero tile size for the fused loops since
  // they have already been tiled.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Truncate the tile sizes to the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile only if there is at least one non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.eraseOp(tiledAndFusedOps->op);
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }

  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers)
    fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back())
    originalOpMarker.replaceLinalgMarker(rewriter,
                                         origProducerOp.getOperation());
  rewriter.updateRootInPlace(
      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
  return success();
}
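// Example (sketch): for a linalg.generic with iterator order (i, j), an
// interchangeVector of {1, 0} permutes the loop order to (j, i) by composing
// the permutation into the op's indexing maps; the access pattern of each
// operand is preserved, only the iteration order changes.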
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(op, [&]() {
    interchange(linalgOp, interchangeVector);
    // Set the new marker if one was specified.
    marker.replaceLinalgMarker(rewriter, op);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(marker.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  marker.replaceLinalgMarker(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();
  vectorizeLinalgOp(rewriter, op);
  rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
    const FrozenRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
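// Example (sketch): staged application from a pass, with illustrative
// pattern lists; only applyStagedPatterns itself is defined in this file:
//
//   SmallVector<FrozenRewritePatternList, 2> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   stage1.emplace_back(std::move(promotionPatterns));
//   FrozenRewritePatternList stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2,
//                                  /*stage3Lambda=*/nullptr)))
//     signalPassFailure();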
/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
/// `ubVal` to `dims` and `stepVal` to `symbols`.
/// Create new AffineDimExprs (`%lb` and `%ub`) and a new AffineSymbolExpr
/// (`%step`) with positions matching the newly appended values. Substitute
/// occurrences of `dimExpr` by either the min expression (i.e. `%lb`) or the
/// max expression (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`),
/// depending on whether the induction variable is used with a positive or
/// negative coefficient.
static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
                                       Value lbVal, Value ubVal, Value stepVal,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = lbVal.getContext();
  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(lbVal);
  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
  dims.push_back(ubVal);
  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
  symbols.push_back(stepVal);
  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
  AffineExpr ee = substWithMin(expr, dimExpr, lb,
                               lb + step * ((ub - 1) - lb).floorDiv(step));
  LLVM_DEBUG(DBGS() << "After: " << ee << "\n");
  return ee;
}

/// Traverse `dims` and substitute known min or max expressions in place of
/// induction variables in the results of `map`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        AffineExpr substitutedExpr;
        if (auto forOp = scf::getForInductionVarOwner(dim))
          substitutedExpr = substituteLoopInExpr(
              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
              forOp.step(), dims, symbols);

        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
               ++idx)
            substitutedExpr = substituteLoopInExpr(
                expr, dimExpr, parallelForOp.lowerBound()[idx],
                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
                dims, symbols);

        if (!substitutedExpr)
          continue;

        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the results. This needs to happen outside of the
    // loop iterating on dims.size() since it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto simplifiedMap = AffineMap::get(dims.size(), symbols.size(), exprs,
                                        exprs.front().getContext());

    LLVM_DEBUG(DBGS() << "Map to simplify: " << simplifiedMap << "\n");

    // Pull in affine.apply operations and compose them fully into the result.
    fullyComposeAffineMapAndOperands(&simplifiedMap, &operands);
    canonicalizeMapAndOperands(&simplifiedMap, &operands);
    simplifiedMap = simplifyAffineMap(simplifiedMap);
    // Assign the results.
    exprs.assign(simplifiedMap.getResults().begin(),
                 simplifiedMap.getResults().end());
    dims.assign(operands.begin(),
                operands.begin() + simplifiedMap.getNumDims());
    symbols.assign(operands.begin() + simplifiedMap.getNumDims(),
                   operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << simplifiedMap << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
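// Example (sketch): this canonicalization folds an affine.min whose value is
// provably constant over a loop's iteration space. Given
//
//   scf.for %i = %c0 to %c10 step %c2 {
//     %0 = affine.min affine_map<(d0) -> (2, 10 - d0)>(%i)
//     ...
//   }
//
// the induction variable %i only takes the values 0, 2, 4, 6 and 8, so
// 10 - %i >= 2 at every iteration and %0 can be replaced by the constant 2.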
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true unless the expression is a constant that is known to be
    // non-negative.
    auto isNotKnownNonNegative = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that everything is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "Simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNotKnownNonNegative))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands = dims;
      resultOperands.append(symbols.begin(), symbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}