1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/ADT/ScopeExit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <type_traits> 32 33 #define DEBUG_TYPE "linalg-transforms" 34 35 using namespace mlir; 36 using namespace mlir::linalg; 37 38 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 39 40 //===----------------------------------------------------------------------===// 41 // Transformations exposed as rewrite patterns. 42 //===----------------------------------------------------------------------===// 43 // Marker used as attribute name in generated Linalg rewriting transformations. 
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

/// Construct a filter that matches when the op's marker attribute equals any
/// entry of `matchDisjunction` (or when the list is empty and the op has no
/// marker). `replacement`, if present, is written back as the new marker.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

/// Same as above, additionally registering an extra predicate `f` that must
/// succeed on the op for the filter to match. A null `f` is ignored.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

/// Return success if `op` passes all registered filter functions and its
/// `kLinalgTransformMarker` attribute matches the disjunction. On mismatch,
/// notify the rewriter so pattern-application diagnostics explain the failure.
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  // 1. Reject if any extra filter function fails on the op.
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 2. Op has no marker and matchDisjunction is empty: trivially matches.
    if (matchDisjunction.empty())
      return success();

    // 3. Op has no marker but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 4. Match the marker against the explicit disjunction.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 5. Marker present but matches none of the expected values.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

/// Advance the transformation state of `op`: set the marker attribute to
/// `replacement` when one was provided, otherwise remove the marker entirely
/// so the op is no longer targeted by marker-based patterns.
void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

/// Install a tile-size computation function that materializes the constant
/// sizes `ts`. The constants are created at the start of the enclosing
/// FuncOp's entry block — presumably so they dominate any tiled loop nest;
/// tile sizes are captured by value in the closure.
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`
/// Return success if either:
/// 1. The operand is already statically shaped, `result` is left unchanged.
/// 2. The operand is (partially) dynamic, `result` is the result of a freshly
/// created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Already static shape, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // A size is either a constant attribute, or a Value for which a constant
    // upper bound may be derivable.
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  // Pad to the bounding box with the caller-provided padding value.
  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opToPad->getLoc(), rewriter);
  return success();
}

/// Rewrite `opToPad` so that every dynamically shaped tensor operand is padded
/// to a static bounding box, then extract slices of the padded results back to
/// the original (dynamic) sizes and replace `opToPad`'s uses with them.
/// `paddedOp` receives the cloned, statically shaped op.
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    // An untouched (already static) operand keeps its original value.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify result shapes of the original op so the padded results can be
  // sliced back down to the original dynamic sizes.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    // Zero offsets, unit strides: slice out the leading original extent along
    // every dimension of the padded result.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}

/// Linalg base tiling pattern, anchored on a specific op name.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

/// Linalg base tiling pattern that matches any op (filtered dynamically).
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Tile (and optionally pad) `op`, returning the tiling artifacts through
/// `result`. The derived pattern is responsible for replacing `op`'s uses.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Setup RAII guard to return properly.
  LinalgOp paddedOp;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to derived pattern.
    result = *res;
    // Update filter on whichever op survives: the padded clone if padding
    // succeeded, else the tiled op.
    if (paddedOp)
      filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    else
      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    res->op = paddedOp;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  // NOTE(review): the scope_exit above still runs on this path and assigns
  // `result = *res`; confirm whether that matches this comment's intent.
  return failure();
}

/// Results to use as replacement for a tiled op: the outermost generated loop
/// if any, otherwise the tiled op itself.
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

/// Results to use as replacement for a tiled-and-fused op: the outermost fused
/// loop if any, otherwise the op itself.
static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

/// Tile `op` and fuse the producers selected by `fusionOptions.indicesToFuse`,
/// then tile any remaining unfused loops. Markers are updated on the tiled
/// op, the fused producers, and the original producers.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect `op` plus the producers feeding the operands selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the fusion candidates in block order, ending with `op` itself.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Fused loop dimensions get a zero tile size so they
  // are not re-tiled.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Advance the marker state on all ops produced by the transformation.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

/// Interchange the iterators of a GenericOp according to `interchangeVector`,
/// in place, after the precondition check succeeds.
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

/// Promote the subviews used by `op` per `options`, updating the op in place.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesnt seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

/// Vectorization pattern matching any op (filtered dynamically).
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

/// Vectorization pattern anchored on a specific op name.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

/// Vectorize a LinalgOp, replacing it with its vectorized results or erasing
/// it when vectorization produced no replacement values.
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

/// For each pattern set in `stage1Patterns`, greedily apply it, then the
/// common `stage2Patterns`, then the optional `stage3Lambda`. Fails as soon
/// as any stage fails to converge/succeed.
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  // `iteration` only feeds debug output; (void)-cast silences the unused
  // warning in release builds.
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Build an iterator-type list of `nParallelLoops` parallel entries.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
/// with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy original contents into new tensor
  // Uses linalg.generic, but could be done with tensor.insert_slice
  // The output map shifts each dimension by the static low-padding amount so
  // the source lands at its padded offset.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  // Constant padding value: a single FillOp suffices.
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

/// Lower a PadTensorOp to InitTensorOp + fill (or generate) + copy of the
/// source into the padded tensor, using an optional `optimizeCopyFn` hook or
/// a tensor::InsertSliceOp as the fallback copy.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // Insert the source at its low-pad offset within the filled tensor.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Swap an ExtractSliceOp of a PadTensorOp so the slice is taken first and
/// the padding is applied to the (smaller) slice.
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}