//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker attribute and `matchDisjunction` is empty:
    //    the filter matches.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no marker attribute but one was expected: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `opOperand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgTilingOptions &options, Value &result) {
  // Already statically shaped, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for all sizes. For now, return an
    // error if we can't find one.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}

/// Try to create a static bounding box around each operand of `res.op`.
/// If successful, `res.op` is rewritten in static form with padded operands.
/// `res.op` is updated to the cloned static form of the op on success.
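///
/// Illustrative sketch (IR abbreviated; not tied to a specific MLIR version):
/// given
///   %s = tensor.extract_slice %t[..][%sz0, %sz1][..] : ... to tensor<?x?xf32>
///   %r = linalg.matmul ins(%s, ...) outs(...) -> tensor<?x?xf32>
/// where %sz0 and %sz1 have known constant upper bounds, the op is cloned to
/// operate on operands padded up to those bounds, and tensor.extract_slice ops
/// recover the original dynamic shapes from the padded results.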
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, options, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
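/// Derived patterns call `matchAndRewriteBase` and then decide how to replace
/// the original op. A minimal sketch of a hypothetical derived pattern (the
/// replacement policy shown here is an assumption, not mandated by the base
/// class):
///
///   struct MyTilingPattern : public LinalgBaseTilingPattern {
///     using LinalgBaseTilingPattern::LinalgBaseTilingPattern;
///     LogicalResult matchAndRewrite(Operation *op,
///                                   PatternRewriter &rewriter) const override {
///       TiledLinalgOp tiledOp;
///       if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
///         return failure();
///       // Replace the original op with the tiled results, or erase it when
///       // it has no tensor results.
///       if (tiledOp.tensorResults.empty())
///         rewriter.eraseOp(op);
///       else
///         rewriter.replaceOp(op, tiledOp.tensorResults);
///       return success();
///     }
///   };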
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the filter on both the tiled op and the tiled-and-padded op, if
    // necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Returning failure here prevents the derived pattern from acting on the
    // TiledLinalgOp propagated through `result`.
    return failure();
  }

  // Do not perform replacement of `linalgOp`; let the derived patterns do
  // this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into an op, the indexing op is always an
    // OpOperand. We could assert, but simply continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loops only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
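/// `interchangeVector` is interpreted as a permutation of the iteration
/// dimensions of the generic op; e.g., {1, 0} on a 2-D op swaps its two
/// loops. A minimal usage sketch (assuming the filter and benefit parameters
/// take their default values):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<GenericOpInterchangePattern>(
///       ctx, /*interchangeVector=*/ArrayRef<unsigned>{1, 0});
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));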
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Traverse the `dims` and substitute known min or max expressions returned by
/// the lambda `getMinMaxExpr`.
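/// For example, following the substitution rule implemented by `substWithMin`:
/// in an expression such as `d0 - d1`, the occurrence of d0 (positive
/// coefficient) is replaced by the min expression of d0, and the occurrence of
/// d1 (negative coefficient) by the max expression of d1, which yields a lower
/// bound on the original expression.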
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols,
                            GetMinMaxExprFn getMinMaxExpr) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        auto minMax = getMinMaxExpr(dim, dims, symbols);
        if (!minMax)
          continue;
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
        // Substitute occurrences of `dimExpr` by either the min expression or
        // the max expression depending on whether the value is used with a
        // positive or negative coefficient.
        AffineExpr substitutedExpr =
            substWithMin(expr, dimExpr, minMax->first, minMax->second);
        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG({
      DBGS() << "Map to simplify: " << map << "\n";
      DBGS() << "Operands:\n";
      for (Value v : operands)
        DBGS() << v << "\n";
    });

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
/// dimensions with a known range by new expressions involving the min or max
/// expression:
///   - If the AffineDimExpr mapped to a known value has a positive sign, it
///     is replaced by the min expression.
///   - If the AffineDimExpr mapped to a known value has a negative sign, it
///     is replaced by the max expression.
/// All known values are iteratively replaced.
/// This is used as an intermediate step in computing bounding boxes and in
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
/// have positive values (positive orthant assumption).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
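///
/// Illustrative example: given `affine.min affine_map<(d0) -> (256 - d0, 8)>`
/// where `getMinMaxExpr` reports the range [0, 248] for d0, the negative
/// coefficient on d0 selects its max expression, so `256 - d0` becomes
/// `256 - 248 = 8`; AffineMinRangeCanonicalizationPattern below can then fold
/// the whole affine.min to the constant 8.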
AffineMapAndOperands
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
                            GetMinMaxExprFn getMinMaxExpr) {
  AffineMapAndOperands res{affineMinOp.getAffineMap(),
                           SmallVector<Value>(affineMinOp.getDimOperands()),
                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
                       getMinMaxExpr);
  return res;
}

LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
  AffineMap map = affineMapAndOperands.map;

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    auto isNonPositive = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that everything is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "Simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNonPositive))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy the contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. a pad value that is
  // either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
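  // For example (illustrative IR, syntax abbreviated), a supported op is:
  //   %0 = linalg.pad_tensor %src low[0, 0] high[2, 2] {
  //   ^bb0(%i: index, %j: index):
  //     linalg.yield %cst : f32  // %cst is defined above the pad op.
  //   } : tensor<6x14xf32> to tensor<8x16xf32>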
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // This uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp = dyn_cast<linalg::YieldOp>(
      generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
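
/// Lower a PadTensorOp to an InitTensorOp (allocating the padded shape), a
/// fill produced by createFillOrGenerateOp above, and an InsertSliceOp copying
/// the source into place, unless `optimizeCopyFn` handles the copy. Sketch of
/// the generated IR (illustrative only; the exact assembly format depends on
/// the MLIR version):
///   %init = linalg.init_tensor [...] : tensor<...>
///   %fill = linalg.fill(%pad_value, %init) : ...
///   %res  = tensor.insert_slice %src into %fill[..][..][..] : ...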
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // sizes is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
/// it must be an IntegerAttr.
static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
  if (auto val = ofr.dyn_cast<Value>())
    return val;
  auto intVal = getConstantIntValue(ofr);
  assert(intVal && "expected Value or IntegerAttr");
  return builder.create<ConstantIndexOp>(loc, *intVal);
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();
  // Only a constant padding value is supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing the new offset/length and padding
  // values.
  Location loc = sliceOp.getLoc();
  AffineExpr dim0, dim1;
  bindDims(rewriter.getContext(), dim0, dim1);
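  // The helpers below build affine ops on the fly; e.g. `add(%a, %b)` emits
  // `affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %b)` (or folds to a
  // constant when both operands are constants, since createOrFold is used).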
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
                                                ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
                                                ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
  auto min = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);

  // Helper function for filling the static/dynamic low/high padding index
  // vectors of the PadTensorOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute the new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition is
  // true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
    auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
    auto srcSize =
        rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad
                          ? min(max(sub(offset, low), zero), srcSize)
                          : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position
    // `offset + length`. Therefore, the corresponding position within the
    // source tensor is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride is supported.
    newStrides.push_back(rewriter.getIndexAttr(1));
  }

  // Insert a cast to ensure that the types match. (It may be folded away.)
  auto castResult = [&](Value val) -> Value {
    auto castOp = rewriter.create<tensor::CastOp>(loc, sliceOp.getType(), val);
    return castOp;
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate an ExtractSliceOp. (The result shape of the
  // ExtractSliceOp would have a dimension of size 0, the semantics of which
  // are unclear.)
  auto createGenerateOp = [&]() {
    // The shape of the GenerateOp is the same as the existing ExtractSliceOp.
    RankedTensorType type = sliceOp.getType();
    SmallVector<Value> dynDims;
    for (unsigned i = 0; i < type.getRank(); ++i) {
      if (type.isDynamicDim(i))
        dynDims.push_back(asValue(rewriter, loc, sliceOp.getMixedSizes()[i]));
    }

    // Create the GenerateOp.
    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);

    // Copy the region to the new op.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
    // Rewrite linalg::YieldOp to tensor::YieldOp.
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto yieldOp = dyn_cast<linalg::YieldOp>(
          generateOp.getRegion().front().getTerminator());
      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
      assert(yieldOp.values().size() == 1);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp,
                                                   yieldOp.values()[0]);
    }

    return castResult(generateOp);
  };

  // Emit an ExtractSliceOp and a PadTensorOp. This should not be used in
  // cases where the result shape of the new ExtractSliceOp has a zero
  // dimension.
  auto createPadTensorOfExtractSlice = [&]() {
    // Create pad_tensor(extract_slice(x)).
    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, padOp.source(), newOffsets, newLengths, newStrides);
    auto newPadTensorOp = rewriter.create<PadTensorOp>(
        loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);

    // Copy the region to the new PadTensorOp.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);

    // Cast the result and return.
    return castResult(newPadTensorOp);
  };

  // Rewrite extract_slice(pad_tensor(x)) into a GenerateOp if it is
  // statically known that the original data source x is not used.
  if (hasZeroLen) {
    rewriter.replaceOp(sliceOp, createGenerateOp());
    return success();
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating ExtractSliceOps with result dimensions of size 0 at runtime.
  if (dynHasZeroLenCond) {
    auto result = rewriter.create<scf::IfOp>(
        loc, sliceOp.getType(), dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp());
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadTensorOfExtractSlice());
        });
    rewriter.replaceOp(sliceOp, result.getResult(0));
    return success();
  }

  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, createPadTensorOfExtractSlice());
  return success();
}