//===- LinalgTransforms.cpp - Linalg transformations as patterns ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no filter attribute and matchDisjunction is empty: match.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no filter attribute but one was expected: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `opOperand`.
/// Return success if either:
/// 1. The operand is already statically shaped; `result` is left unchanged.
/// 2. The operand is (partially) dynamic; `result` is the result of a freshly
///    created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Already statically shaped, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // The smallest bounding index must exist for all sizes; for now, return an
    // error if we cannot find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "no constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}

LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
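  // For intuition: given an op whose operands come from extract_slices with
  // dynamic sizes, this rewrite turns (illustrative IR only; the op, shapes,
  // and names are hypothetical)
  //
  //   %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
  //                      outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
  //
  // into the same op on statically padded operands (via linalg.pad_tensor),
  // followed by tensor.extract_slice ops that recover the original dynamic
  // extents of each result.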
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for the uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });
  return success();
}

/// Linalg base tiling pattern.
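/// As a rough usage sketch (not taken from this file; the op type, tile sizes,
/// and marker name are hypothetical), a derived op-specific pattern is
/// typically registered and driven as:
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<LinalgTilingPattern<MatmulOp>>(
///       ctx, LinalgTilingOptions().setTileSizes({8, 16, 32}),
///       LinalgTransformationFilter(ArrayRef<Identifier>{},
///                                  Identifier::get("tiled", ctx)));
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));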
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res->op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    res->op = paddedOp;
    // Do not perform the replacement of `linalgOp` here; let the derived
    // patterns do it as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Padding was requested but failed to apply: fail to match so the caller
  // does not use the TiledLinalgOp in `result`.
  return failure();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove the hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
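  // (A tile size of zero conventionally means "do not tile this loop", which
  // is why fused loop dimensions received `zero` above; the check below skips
  // the extra tiling step when every remaining size is a constant zero.)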
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Apply the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops, so
  // if the promotion fails, those need to be cleaned up, which doesn't seem to
  // be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
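  // The update protocol below brackets promoteSubViews: start the root
  // update, then either cancel it when promotion fails or finalize it on
  // success.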
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
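  // For example (illustrative IR; here the constant pad value %cst is defined
  // outside the pad_tensor region, so the pattern applies):
  //
  //   %cst = constant 0.0 : f32
  //   %0 = linalg.pad_tensor %arg0 low[0, 1] high[2, 3] {
  //   ^bb0(%i: index, %j: index):
  //     linalg.yield %cst : f32
  //   } : tensor<4x4xf32> to tensor<6x8xf32>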
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible;
/// otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
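  // (An OpFoldResult holds either an Attribute or a Value; the helper below
  // materializes the static Attribute case as a ConstantIndexOp.)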
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate an InsertSliceOp instead for
  // copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
/// it must be of integer type.
static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
  if (auto val = ofr.dyn_cast<Value>())
    return val;
  auto intVal = getConstantIntValue(ofr);
  assert(intVal && "expected Value or IntegerAttr");
  return builder.create<ConstantIndexOp>(loc, *intVal);
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();
  // Only a constant padding value is supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = sliceOp.getLoc();
  AffineExpr dim0, dim1;
  bindDims(rewriter.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
                                                ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
                                                ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
  auto min = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);

  // Helper function for filling the static/dynamic low/high padding index
  // vectors of the PadTensorOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute the new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition is
  // true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
    auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
    auto srcSize =
        rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad
                          ? min(max(sub(offset, low), zero), srcSize)
                          : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position
    // `offset + length`. Therefore, the corresponding position within the
    // source tensor is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone. In
    // that case, the end position of the read should be the end of the source
    // tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride is supported.
    newStrides.push_back(rewriter.getIndexAttr(1));
  }

  // Insert a cast to ensure that the types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    auto castOp = rewriter.create<tensor::CastOp>(loc, sliceOp.getType(), val);
    return castOp;
  };

  // In cases where the original data source is unused: emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would have a
  // dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // The shape of the GenerateOp is the same as the existing SliceOp.
    RankedTensorType type = sliceOp.getType();
    SmallVector<Value> dynDims;
    for (unsigned i = 0; i < type.getRank(); ++i) {
      if (type.isDynamicDim(i))
        dynDims.push_back(asValue(rewriter, loc, sliceOp.getMixedSizes()[i]));
    }

    // Create the GenerateOp.
    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);

    // Copy the region to the new op.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
    // Rewrite linalg::YieldOp to tensor::YieldOp.
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto yieldOp = dyn_cast<linalg::YieldOp>(
          generateOp.getRegion().front().getTerminator());
      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
      assert(yieldOp.values().size() == 1);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp,
                                                   yieldOp.values()[0]);
    }

    return castResult(generateOp);
  };

  // Emit a SliceOp and a PadTensorOp. Should not be used in cases where the
  // result shape of the new SliceOp has a zero dimension.
  auto createPadTensorOfExtractSlice = [&]() {
    // Create pad_tensor(extract_slice(x)).
    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, padOp.source(), newOffsets, newLengths, newStrides);
    auto newPadTensorOp = rewriter.create<PadTensorOp>(
        loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);

    // Copy the region to the new PadTensorOp.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);

    // Cast the result and return.
    return castResult(newPadTensorOp);
  };

  // Rewrite extract_slice(pad_tensor(x)) into a GenerateOp if it is
  // statically known that the original data source x is not used.
  if (hasZeroLen) {
    rewriter.replaceOp(sliceOp, createGenerateOp());
    return success();
  }

  // If there are dynamic dimensions: generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (dynHasZeroLenCond) {
    auto result = rewriter.create<scf::IfOp>(
        loc, sliceOp.getType(), dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp());
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadTensorOfExtractSlice());
        });
    rewriter.replaceOp(sliceOp, result.getResult(0));
    return success();
  }

  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, createPadTensorOfExtractSlice());
  return success();
}
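
// As a rough usage sketch for the swap pattern above (not part of this file;
// `funcOp` is a hypothetical FuncOp to run the rewrite on):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));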