1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Support/LLVM.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include "llvm/ADT/ScopeExit.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::linalg; 35 36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 37 38 //===----------------------------------------------------------------------===// 39 // Transformations exposed as rewrite patterns. 40 //===----------------------------------------------------------------------===// 41 // Marker used as attribute name in generated Linalg rewriting transformations. 42 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 43 "__internal_linalg_transform__"; 44 45 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 46 ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement) 47 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 48 replacement(replacement) {} 49 50 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 51 FilterFunction f, ArrayRef<Identifier> matchDisjunction, 52 Optional<Identifier> replacement) 53 : filters(), 54 matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 55 replacement(replacement) { 56 if (f) 57 filters.push_back(f); 58 } 59 60 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify( 61 PatternRewriter &rewriter, Operation *op) const { 62 if (llvm::any_of(filters, 63 [&](const FilterFunction &f) { return failed(f(op)); })) 64 return failure(); 65 66 auto attr = op->template getAttrOfType<StringAttr>( 67 LinalgTransforms::kLinalgTransformMarker); 68 69 if (!attr) { 70 // 1. Has no filter case and matchDisjunction is empty. 71 if (matchDisjunction.empty()) 72 return success(); 73 74 // 2. Has no filter but was expecting a filter. 75 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 76 diag << " does not have any filter from list: "; 77 interleaveComma(matchDisjunction, diag); 78 }); 79 } 80 81 // 4. Match explicit filter. 82 for (auto filter : matchDisjunction) 83 if (attr.getValue() == filter) 84 return success(); 85 86 // 5. Fail to match. 87 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 88 diag << " does not have any filter from list: "; 89 interleaveComma(matchDisjunction, diag); 90 }); 91 } 92 93 void mlir::linalg::LinalgTransformationFilter:: 94 replaceLinalgTransformationFilter(PatternRewriter &rewriter, 95 Operation *op) const { 96 if (replacement.hasValue()) 97 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 98 rewriter.getStringAttr(replacement.getValue().strref())); 99 else 100 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 101 rewriter.getContext())); 102 } 103 104 LinalgTilingOptions & 105 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 106 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 107 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 108 OpBuilder::InsertionGuard guard(b); 109 b.setInsertionPointToStart( 110 &op->getParentOfType<FuncOp>().getBody().front()); 111 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 112 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 113 return v; 114 })); 115 }; 116 return *this; 117 } 118 119 /// Try to compute a static bounding box for `operand` 120 /// Return success if either: 121 /// 1. The operand is already statically shaped, `result` is left unchanged. 122 /// 2. The operand is (partially) dynamic, `result` is the result of a freshly 123 /// created PadTensorOp. 124 /// Return failure if the operand cannot be padded to a static shape. 125 static LogicalResult padOperandToSmallestStaticBoundingBox( 126 PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand, 127 const LinalgTilingOptions &options, Value &result) { 128 auto tensorType = operand.get().getType().cast<RankedTensorType>(); 129 // Already static shape, no need to pad. 130 if (tensorType.hasStaticShape()) 131 return success(); 132 auto subtensor = operand.get().getDefiningOp<SubTensorOp>(); 133 // Not a subtensor, cannot construct a static bounding box. 134 if (!subtensor) 135 return failure(); 136 SmallVector<int64_t> staticSizes; 137 staticSizes.reserve(tensorType.getRank()); 138 auto shapedOp = 139 cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation()); 140 for (auto size : shapedOp.getMixedSizes()) { 141 auto indexAttr = size.is<Attribute>() 142 ? size.get<Attribute>().dyn_cast<IntegerAttr>() 143 : linalg::getSmallestBoundingIndex(size.get<Value>()); 144 // SmallestBoundingIndex must exist for all sizes. 145 // For now return an error if we can't find it. 146 if (!indexAttr) 147 return rewriter.notifyMatchFailure( 148 opToPad, "No constant bounding box can be found for padding"); 149 staticSizes.push_back(indexAttr.getInt()); 150 } 151 Value pad = options.paddingValueComputationFunction(rewriter, operand); 152 auto staticTensorType = 153 RankedTensorType::get(staticSizes, tensorType.getElementType()); 154 result = linalg::PadTensorOp::createPadHighOp( 155 staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter); 156 return success(); 157 } 158 159 // Try to create a static bounding box around each operand of `res.op`. 160 // If successful, `res.op` is rewritten in static form with padded operands. 161 // `res.op` is updated to the cloned static form of the op on success. 162 static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, 163 TiledLinalgOp &res, 164 const LinalgTilingOptions &options) { 165 LinalgOp opToPad = res.op; 166 Location loc = opToPad->getLoc(); 167 168 // If the op is fully static, it does not need padding. 169 // TODO: there are cases where we may still want to pad to larger sizes. 170 assert(opToPad.hasTensorSemantics() && 171 "expected operation to have tensor semantics"); 172 if (!opToPad.hasDynamicShape()) 173 return success(); 174 175 OpBuilder::InsertionGuard g(rewriter); 176 // Set IP after op because we also take the dims of the original output. 177 rewriter.setInsertionPointAfter(opToPad); 178 // Make a copy of the shaped operands and update it. 179 SmallVector<Value> newOperands; 180 newOperands.reserve(opToPad.getNumInputsAndOutputs()); 181 for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { 182 Value paddedOperand; 183 // If padding was requested but the shape cannot be bounded statically then 184 // the pattern fails to apply. 185 if (failed(padOperandToSmallestStaticBoundingBox( 186 rewriter, opToPad, *opOperand, options, paddedOperand))) { 187 return failure(); 188 } 189 newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); 190 } 191 192 // Clone `opToPad` to operate on the statically padded shapes. 193 auto resultTensorTypes = 194 ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); 195 ValueRange otherOperands = opToPad.getAssumedNonShapedOperands(); 196 newOperands.append(otherOperands.begin(), otherOperands.end()); 197 linalg::LinalgOp paddedOp = 198 opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); 199 200 // Recover the subtensor out of the new static results. This keeps the 201 // original linalg op around because it uses the dims of the original results. 202 // This later folds away. 203 SmallVector<Value> paddedSubviewResults; 204 paddedSubviewResults.reserve(opToPad->getNumResults()); 205 SetVector<Operation *> newUsersOfOpToPad; 206 for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) { 207 auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank(); 208 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); 209 auto sizes = llvm::to_vector<4>(llvm::map_range( 210 llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult { 211 auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d); 212 newUsersOfOpToPad.insert(dimOp); 213 return dimOp.getResult(); 214 })); 215 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); 216 paddedSubviewResults.push_back(rewriter.create<SubTensorOp>( 217 loc, std::get<1>(it), offsets, sizes, strides)); 218 } 219 // Replace the transient `opToPad` locally, except for uses that we just 220 // created for the purpose of extracting the dims. 221 rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) { 222 return !newUsersOfOpToPad.contains(opOp.getOwner()); 223 }); 224 225 res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults}; 226 return success(); 227 } 228 229 /// Linalg base tiling pattern. 230 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 231 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 232 LinalgTransformationFilter filter, PatternBenefit benefit) 233 : RewritePattern(opName, benefit, context), filter(filter), 234 options(options) {} 235 236 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 237 MLIRContext *context, LinalgTilingOptions options, 238 LinalgTransformationFilter filter, PatternBenefit benefit) 239 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter), 240 options(options) {} 241 242 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( 243 Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const { 244 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 245 if (!linalgOp) 246 return failure(); 247 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 248 return failure(); 249 250 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 251 252 if (!res) 253 return failure(); 254 255 // Setup RAII guard to return properly. 256 LinalgOp tiledOp = res->op; 257 auto guard = llvm::make_scope_exit([&]() { 258 // Return relevant information to derived pattern. 259 result = *res; 260 // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary. 261 filter.replaceLinalgTransformationFilter(rewriter, tiledOp); 262 if (tiledOp != res->op) 263 filter.replaceLinalgTransformationFilter(rewriter, res->op); 264 }); 265 266 // Consider padding on the fly only if the op has tensor semantics. 267 if (!options.paddingValueComputationFunction || 268 !linalgOp.hasTensorSemantics()) 269 return success(); 270 271 // Try to pad on the fly by rewriting res->op as a padded op. 272 if (failed(rewriteAsPaddedOp(rewriter, *res, options))) { 273 // Set so RAII guard does not propagate TiledLinalgOp to `result`. 274 return failure(); 275 } 276 277 // Do not perform replacement of `linalgOp`, let the derived patterns 278 // do this as they see fit, from the resulting TiledLinalgOp. 279 return success(); 280 } 281 282 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) { 283 if (tiledOp.loops.empty()) 284 return tiledOp.op.getOperation()->getResults(); 285 return tiledOp.loops.front()->getResults(); 286 } 287 288 static ValueRange 289 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) { 290 if (tiledAndFusedOp.fusedLoops.empty()) 291 return tiledAndFusedOp.op.getOperation()->getResults(); 292 return tiledAndFusedOp.fusedLoops.front()->getResults(); 293 } 294 295 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( 296 StringRef opName, MLIRContext *context, 297 const LinalgDependenceGraph &dependenceGraph, 298 LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, 299 LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker, 300 LinalgTransformationFilter originalOpMarker, PatternBenefit benefit) 301 : RewritePattern(opName, benefit, context, {}), 302 dependenceGraph(dependenceGraph), tilingOptions(tilingOptions), 303 fusionOptions(fusionOptions), filter(filter), 304 fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {} 305 306 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( 307 Operation *op, PatternRewriter &rewriter) const { 308 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 309 // TODO: remove hasIndexSemantics check once index ops are supported. 310 if (!linalgOp || linalgOp.hasIndexSemantics()) 311 return failure(); 312 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 313 return failure(); 314 315 DenseSet<Operation *> producers; 316 producers.insert(linalgOp); 317 for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { 318 Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum(); 319 // When looking at dependences into, indexingOp is always OpOperand. We 320 // could assert, but continue if this is not the case. 321 if (!operandNumber) 322 continue; 323 if (!fusionOptions.indicesToFuse.count(operandNumber.getValue())) 324 continue; 325 if (isa<LinalgOp>(dependence.getDependentOp())) 326 producers.insert(dependence.getDependentOp()); 327 } 328 329 SmallVector<LinalgOp, 1> fusionOps; 330 for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; 331 ++it) { 332 auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it)); 333 if (producerLinalgOp && producers.count(producerLinalgOp)) 334 fusionOps.push_back(producerLinalgOp); 335 } 336 fusionOps.push_back(linalgOp); 337 338 SmallVector<Value, 4> tileSizes = 339 tilingOptions.tileSizeComputationFunction(rewriter, op); 340 LinalgTilingOptions instanceTilingOptions = tilingOptions; 341 instanceTilingOptions.setTileSizes(tileSizes); 342 Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps( 343 rewriter, fusionOps, dependenceGraph, instanceTilingOptions); 344 if (!tiledAndFusedOps) 345 return failure(); 346 347 // Tile the unfused loops; 348 SmallVector<Value, 4> unfusedLoopTileSizes; 349 Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0); 350 for (auto tileSize : enumerate(tileSizes)) { 351 if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) 352 unfusedLoopTileSizes.push_back(zero); 353 else 354 unfusedLoopTileSizes.push_back(tileSize.value()); 355 } 356 // Tile the loop only if there is a non-zero tile size. 357 if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) 358 unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); 359 if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { 360 if (auto cst = val.getDefiningOp<ConstantIndexOp>()) 361 return cst.getValue() != 0; 362 return true; 363 })) { 364 LinalgTilingOptions unfusedTilingOptions = tilingOptions; 365 unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); 366 Optional<TiledLinalgOp> unfusedTiledOp = 367 tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); 368 if (!unfusedTiledOp) 369 return failure(); 370 rewriter.replaceOp(tiledAndFusedOps->op, 371 getTiledOpResult(unfusedTiledOp.getValue())); 372 tiledAndFusedOps->op = unfusedTiledOp->op; 373 } 374 op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue())); 375 376 filter.replaceLinalgTransformationFilter(rewriter, 377 tiledAndFusedOps->op.getOperation()); 378 for (auto fusedOp : tiledAndFusedOps->fusedProducers) { 379 fusedOpMarker.replaceLinalgTransformationFilter(rewriter, 380 fusedOp.getOperation()); 381 } 382 for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) { 383 originalOpMarker.replaceLinalgTransformationFilter( 384 rewriter, origProducerOp.getOperation()); 385 } 386 rewriter.updateRootInPlace(op, [&]() { 387 originalOpMarker.replaceLinalgTransformationFilter(rewriter, op); 388 }); 389 return success(); 390 } 391 392 /// Linalg generic interchange pattern. 393 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( 394 MLIRContext *context, ArrayRef<unsigned> interchangeVector, 395 LinalgTransformationFilter filter, PatternBenefit benefit) 396 : OpRewritePattern(context, benefit), filter(filter), 397 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 398 399 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite( 400 GenericOp genericOp, PatternRewriter &rewriter) const { 401 if (failed(filter.checkAndNotify(rewriter, genericOp))) 402 return failure(); 403 if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) 404 return failure(); 405 406 // TODO: figure out how this interplays with named ops. In particular this 407 // should break the named op property. 408 rewriter.updateRootInPlace(genericOp, [&]() { 409 interchangeGenericOp(rewriter, genericOp, interchangeVector); 410 // New filter if specified. 411 filter.replaceLinalgTransformationFilter(rewriter, genericOp); 412 }); 413 return success(); 414 } 415 416 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 417 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 418 LinalgTransformationFilter filter, PatternBenefit benefit) 419 : RewritePattern(opName, benefit, context, {}), filter(filter), 420 options(options) {} 421 422 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 423 Operation *op, PatternRewriter &rewriter) const { 424 if (failed(filter.checkAndNotify(rewriter, op))) 425 return failure(); 426 if (failed(promoteSubviewsPrecondition(op, options))) 427 return failure(); 428 429 // TODO: We cannot use root update here. This pattern is creating other ops, 430 // so if the promotion fails, those need to be cleaned up, which doesnt seem 431 // to be happening here. So to fail properly, we should be cloning the op and 432 // deleting the previous op. This needs more investigation. 433 rewriter.startRootUpdate(op); 434 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 435 if (!promotedOp) { 436 rewriter.cancelRootUpdate(op); 437 return op->emitError("subview promotion failed"); 438 } 439 rewriter.finalizeRootUpdate(op); 440 filter.replaceLinalgTransformationFilter(rewriter, op); 441 return success(); 442 } 443 444 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 445 MLIRContext *context, LinalgTransformationFilter filter, 446 PatternBenefit benefit) 447 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {} 448 449 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 450 StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, 451 PatternBenefit benefit) 452 : RewritePattern(opName, benefit, context, {}), filter(filter) {} 453 454 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 455 Operation *op, PatternRewriter &rewriter) const { 456 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 457 if (!linalgOp) 458 return failure(); 459 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 460 return failure(); 461 SmallVector<Value> newResults; 462 if (failed(vectorizeLinalgOp(rewriter, op, newResults))) 463 return failure(); 464 if (!newResults.empty()) 465 rewriter.replaceOp(op, newResults); 466 else 467 rewriter.eraseOp(op); 468 return success(); 469 } 470 471 LogicalResult mlir::linalg::applyStagedPatterns( 472 Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns, 473 const FrozenRewritePatternSet &stage2Patterns, 474 function_ref<LogicalResult(Operation *)> stage3Lambda) { 475 unsigned iteration = 0; 476 (void)iteration; 477 for (const auto &patterns : stage1Patterns) { 478 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 479 << *op); 480 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 481 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 482 return failure(); 483 } 484 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 485 << *op); 486 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 487 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 488 return failure(); 489 } 490 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 491 << *op); 492 if (stage3Lambda) { 493 if (failed(stage3Lambda(op))) 494 return failure(); 495 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 496 << *op); 497 } 498 } 499 return success(); 500 } 501 502 /// Traverse the `dims` and substitute known min or max expressions returned by 503 /// the lambda |getMinMaxExpr|. 504 static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims, 505 SmallVectorImpl<Value> &symbols, 506 GetMinMaxExprFn getMinMaxExpr) { 507 auto exprs = llvm::to_vector<4>(map.getResults()); 508 for (AffineExpr &expr : exprs) { 509 bool substituted = true; 510 while (substituted) { 511 substituted = false; 512 for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) { 513 Value dim = dims[dimIdx]; 514 auto minMax = getMinMaxExpr(dim, dims, symbols); 515 if (!minMax) 516 continue; 517 AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext()); 518 LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n"); 519 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n"); 520 // Substitute occurrences of `dimExpr` by either the min expression or 521 // the max expression depending on whether the value is used with a 522 // positive or negative coefficient. 523 AffineExpr substitutedExpr = 524 substWithMin(expr, dimExpr, minMax->first, minMax->second); 525 LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n"); 526 substituted = (substitutedExpr != expr); 527 expr = substitutedExpr; 528 } 529 } 530 531 // Cleanup and simplify the results. 532 // This needs to happen outside of the loop iterating on dims.size() since 533 // it modifies dims. 534 SmallVector<Value, 4> operands(dims.begin(), dims.end()); 535 operands.append(symbols.begin(), symbols.end()); 536 auto map = AffineMap::get(dims.size(), symbols.size(), exprs, 537 exprs.front().getContext()); 538 539 LLVM_DEBUG({ 540 DBGS() << "Map to simplify: " << map << "\n"; 541 DBGS() << "Operands:\n"; 542 for (Value v : operands) 543 DBGS() << v << "\n"; 544 }); 545 546 // Pull in affine.apply operations and compose them fully into the 547 // result. 548 fullyComposeAffineMapAndOperands(&map, &operands); 549 canonicalizeMapAndOperands(&map, &operands); 550 map = simplifyAffineMap(map); 551 // Assign the results. 552 exprs.assign(map.getResults().begin(), map.getResults().end()); 553 dims.assign(operands.begin(), operands.begin() + map.getNumDims()); 554 symbols.assign(operands.begin() + map.getNumDims(), operands.end()); 555 556 LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n"); 557 } 558 559 assert(!exprs.empty() && "Unexpected empty exprs"); 560 return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext()); 561 } 562 563 /// Traverse the dims of the AffineMap of `affineMinOp` and substitute 564 /// dimensions with known range by new expressions involving the min or max 565 /// expression: 566 /// - If the AffineDimExpr mapped to a known value has a positive sign, it 567 /// is replaced by the min expression. 568 /// - If the AffineDimExpr mapped to a known value has a negative sign, it is 569 /// replaced by the max expression. 570 /// All known values are iteratively replaced. 571 /// This is used as an intermediate step in computing bounding boxes and 572 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have 573 /// positive values (positive orthant assumptions). 574 /// Return a new AffineMap, dims and symbols that have been canonicalized and 575 /// simplified. 576 AffineMapAndOperands 577 mlir::linalg::substituteMin(AffineMinOp affineMinOp, 578 GetMinMaxExprFn getMinMaxExpr) { 579 AffineMapAndOperands res{affineMinOp.getAffineMap(), 580 SmallVector<Value>(affineMinOp.getDimOperands()), 581 SmallVector<Value>(affineMinOp.getSymbolOperands())}; 582 res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols, 583 getMinMaxExpr); 584 return res; 585 } 586 587 LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite( 588 AffineMinOp minOp, PatternRewriter &rewriter) const { 589 LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation() 590 << "\n"); 591 592 auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn); 593 AffineMap map = affineMapAndOperands.map; 594 595 LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n"); 596 597 // Check whether any of the expressions, when subtracted from all other 598 // expressions, produces only >= 0 constants. If so, it is the min. 599 for (auto e : minOp.getAffineMap().getResults()) { 600 LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n"); 601 if (!e.isSymbolicOrConstant()) 602 continue; 603 604 auto isNonPositive = [](AffineExpr e) { 605 if (auto cst = e.dyn_cast<AffineConstantExpr>()) 606 return cst.getValue() < 0; 607 return true; 608 }; 609 610 // Build the subMap and check everything is statically known to be 611 // positive. 612 SmallVector<AffineExpr, 4> subExprs; 613 subExprs.reserve(map.getNumResults()); 614 for (auto ee : map.getResults()) 615 subExprs.push_back(ee - e); 616 MLIRContext *ctx = minOp.getContext(); 617 AffineMap subMap = simplifyAffineMap( 618 AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx)); 619 LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n"); 620 if (llvm::any_of(subMap.getResults(), isNonPositive)) 621 continue; 622 623 // Static min found. 624 if (auto cst = e.dyn_cast<AffineConstantExpr>()) { 625 rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue()); 626 } else { 627 auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx); 628 SmallVector<Value> resultOperands = affineMapAndOperands.dims; 629 llvm::append_range(resultOperands, affineMapAndOperands.symbols); 630 canonicalizeMapAndOperands(&resultMap, &resultOperands); 631 resultMap = simplifyAffineMap(resultMap); 632 rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap, 633 resultOperands); 634 } 635 return success(); 636 } 637 638 return failure(); 639 } 640 641 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) { 642 return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName()); 643 } 644 645 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize 646 /// with pad_val) and GenericOp (to copy contents). 647 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite( 648 linalg::PadTensorOp padOp, PatternRewriter &rewriter) const { 649 650 auto inputShapedType = padOp.source().getType().cast<ShapedType>(); 651 auto resultShapedType = padOp.result().getType().cast<ShapedType>(); 652 653 // Bail on non-static shapes. 654 if (!inputShapedType.hasStaticShape()) 655 return failure(); 656 if (!resultShapedType.hasStaticShape()) 657 return failure(); 658 659 // Only support padding with a constant for now, i.e. either: 660 // 1. A BBarg from a different block. 661 // 2. A value defined outside of the current block. 662 Block &block = padOp.region().front(); 663 auto yieldOp = cast<YieldOp>(block.getTerminator()); 664 assert(yieldOp.getNumOperands() == 1 && "expected single operand yield"); 665 Value padValue = yieldOp.values().front(); 666 Operation *definingOp = padValue.getDefiningOp(); 667 if (definingOp && definingOp->getBlock() == &block) 668 return failure(); 669 if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block) 670 return failure(); 671 672 // Create tensor with the padded shape 673 Location loc = padOp.getLoc(); 674 SmallVector<Value> indices(resultShapedType.getRank(), 675 rewriter.create<ConstantIndexOp>(loc, 0)); 676 Value initTensor = rewriter.create<InitTensorOp>( 677 loc, resultShapedType.getShape(), resultShapedType.getElementType()); 678 679 // Initialize tensor with the pad value 680 Value tmpTensor = 681 rewriter.create<linalg::FillOp>(loc, initTensor, padValue).result(); 682 683 // Copy original contents into new tensor 684 // Uses linalg.generic, but could be done with std.subtensor_insert 685 SmallVector<AffineExpr, 4> outputExprs; 686 for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { 687 outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + 688 padOp.static_low()[i].cast<IntegerAttr>().getInt()); 689 } 690 691 SmallVector<AffineMap, 2> transferMaps = { 692 rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), 693 AffineMap::get(resultShapedType.getRank(), 694 /*symbolCount=*/0, outputExprs, rewriter.getContext())}; 695 696 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 697 padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps, 698 getNParallelLoopsAttrs(resultShapedType.getRank()), 699 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { 700 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); 701 }); 702 703 return success(); 704 } 705