//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker and `matchDisjunction` is empty: anything matches.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no marker but one was expected: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
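// Illustrative usage (a sketch; `ctx` and the marker strings are
// hypothetical): a filter built as
//   LinalgTransformationFilter(Identifier::get("tile", ctx),
//                              Identifier::get("vectorize", ctx));
// makes checkAndNotify succeed only on ops carrying
// `__internal_linalg_transform__ = "tile"`, and a subsequent call to
// replaceLinalgTransformationFilter relabels them "vectorize" so a later
// staged pattern can pick them up.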
void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<ConstantIndexOp>(loc, 0)
                              : b.create<ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Already statically shaped, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for every size; for now, return an
    // error if we cannot find one.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}
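// Sketch of the intended effect (the IR is illustrative, assuming a constant
// upper bound of 8 can be derived for %sz):
//   %s = tensor.extract_slice %t[0, 0] [%sz, 16] [1, 1]
//        : tensor<?x16xf32> to tensor<?x16xf32>
// is padded with the value produced by `paddingFunc` into
//   %p = linalg.pad_tensor %s ... : tensor<?x16xf32> to tensor<8x16xf32>
// so the consuming op can later be cloned with fully static operand shapes.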
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}
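// Note on the contract above (a sketch, not normative): after padding, each
// result of `paddedOp` is statically shaped, and a tensor.extract_slice with
// zero offsets, unit strides, and the reified original sizes recovers a value
// of the original (possibly dynamic) type, e.g.
//   %res = tensor.extract_slice %padded[0, 0] [%d0, %d1] [1, 1]
//          : tensor<8x16xf32> to tensor<?x?xf32>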
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Only scf.for loops are supported at the moment.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res->loops.size()) &&
           "requested peeling of non-existing loop");
    Operation *loopOp = res->loops[loop];
    SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);
    // The result of the loop nest may change with peeling.
    if (res->tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res->tensorResults.begin(), res->tensorResults.end(),
                   loopOp->getResults().begin()))
      res->tensorResults = loopResults;
  }

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res->op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`; let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Padding was requested but could not be applied: fail without propagating
  // the TiledLinalgOp to `result`.
  return failure();
}
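// Usage sketch (hypothetical driver code; the tile sizes, peeled loop, and
// marker name are illustrative):
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTilingPattern<MatmulOp>>(
//       ctx,
//       LinalgTilingOptions().setTileSizes({8, 16, 4}).setPeeledLoops({0}),
//       LinalgTransformationFilter(Identifier::get("tile", ctx)));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));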
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Drop tile sizes beyond the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
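// Marker flow sketch (illustrative names): with
//   filter           = {"tile_fuse"} -> "tiled_fused",
//   fusedOpMarker    = -> "fused_producer",
//   originalOpMarker = -> "original",
// the tiled root op is relabeled "tiled_fused", each fused producer clone is
// labeled "fused_producer", and the untouched original producers (and the
// original root) are labeled "original" so follow-up patterns can erase or
// skip them.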
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
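// Usage sketch (hypothetical; the operand indices and marker name are
// illustrative):
//   patterns.add<LinalgPromotionPattern<MatmulOp>>(
//       ctx,
//       LinalgPromotionOptions().setOperandsToPromote({0, 1}),
//       LinalgTransformationFilter(Identifier::get("promote", ctx)));
// This would promote the two input subviews of already-tiled matmuls into
// local buffers.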
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy the contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
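// Before/after sketch for the rewrite above (the IR shapes are illustrative):
//   %0 = linalg.pad_tensor %src low[1, 1] high[1, 1] { ... linalg.yield %cst }
//        : tensor<2x2xf32> to tensor<4x4xf32>
// becomes, roughly:
//   %i = linalg.init_tensor [4, 4] : tensor<4x4xf32>
//   %f = linalg.fill(%cst, %i) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
//   %g = linalg.generic ... ins(%src) outs(%f)  // copy shifted by the low pad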
/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1 && "expected single operand yield");
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
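// Sketch of the non-constant case (illustrative IR): a pad whose region
// computes the value, e.g.
//   %0 = linalg.pad_tensor %src low[...] high[...] {
//   ^bb0(%i: index, %j: index):
//     %v = ... // depends on %i, %j
//     linalg.yield %v : f32
//   } : tensor<?x?xf32> to tensor<?x?xf32>
// is lowered to a tensor.generate carrying the same region, with the
// terminator rewritten to tensor.yield.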
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate an InsertSliceOp instead to copy
  // the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
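// Swap sketch (illustrative IR): for
//   %p = linalg.pad_tensor %t ... : tensor<?x?xf32> to tensor<16x16xf32>
//   %s = tensor.extract_slice %p[%o0, %o1] [8, 8] [1, 1]
//        : tensor<16x16xf32> to tensor<8x8xf32>
// the slice is instead taken from %t and the pad is applied to the sliced
// value, i.e. the result is computed as pad_tensor(extract_slice(%t)).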