//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
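//
// For example (illustrative only), a transformation driver may seed a root op
// with `{__internal_linalg_transform__ = "tile"}`; a pattern whose filter
// matches "tile" fires on that op and, on success, rewrites the marker to its
// replacement identifier so the same pattern does not re-apply.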

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
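
// A minimal usage sketch (the tile sizes are hypothetical): the callback
// built above materializes the constants at the start of the enclosing
// function.
//   LinalgTilingOptions options;
//   options.setTileSizes({16, 32, 64});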

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped, `result` is left unchanged.
///   2. The operand is (partially) dynamic, `result` is the result of a freshly
///      created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
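///
/// A sketch with hypothetical shapes: given an operand produced by
///   %s = tensor.extract_slice %t[0, 0] [%sz, 8] [1, 1]
///       : tensor<?x8xf32> to tensor<?x8xf32>
/// where `%sz` has a smallest constant bounding index of 4, the operand is
/// padded to the static bounding box tensor<4x8xf32> with a fresh PadTensorOp.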
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgTilingOptions &options, Value &result) {
  // Already static shape, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}

// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands.
// `res.op` is updated to the cloned static form of the op on success.
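//
// Net effect, sketched on hypothetical types: a matmul whose operands are
// dynamically shaped slices with known bounds is cloned to operate on
// statically padded operands, and each static result is then sliced back to
// the original dynamic sizes with tensor.extract_slice.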
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, options, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up RAII guard to return properly.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to derived pattern.
    result = *res;
    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Set so RAII guard does not propagate TiledLinalgOp to `result`.
    return failure();
  }

  // Do not perform replacement of `linalgOp`, let the derived patterns
  // do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

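/// Tile a LinalgOp and fuse the producers of the operands selected by
/// `fusionOptions.indicesToFuse`, as sketched below: producers are gathered
/// through the dependence graph, tiled and fused together, and the remaining
/// unfused loops are then tiled with a zero tile size for every already-fused
/// loop dimension.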
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Traverse the `dims` and substitute known min or max expressions returned by
/// the lambda `getMinMaxExpr`.
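///
/// For example (hypothetical map and range): substituting `d0`, known to lie
/// in [%lb, %ub], into `(d0)[s0] -> (s0 - d0)` replaces `d0` by the max
/// expression `%ub`, because `d0` occurs with a negative coefficient. The
/// result `s0 - %ub` is a guaranteed lower bound of the original expression.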
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols,
                            GetMinMaxExprFn getMinMaxExpr) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        auto minMax = getMinMaxExpr(dim, dims, symbols);
        if (!minMax)
          continue;
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
        // Substitute occurrences of `dimExpr` by either the min expression or
        // the max expression depending on whether the value is used with a
        // positive or negative coefficient.
        AffineExpr substitutedExpr =
            substWithMin(expr, dimExpr, minMax->first, minMax->second);
        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Cleanup and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG({
      DBGS() << "Map to simplify: " << map << "\n";
      DBGS() << "Operands:\n";
      for (Value v : operands)
        DBGS() << v << "\n";
    });

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
/// dimensions with known range by new expressions involving the min or max
/// expression:
///   - If the AffineDimExpr mapped to a known value has a positive sign, it
///     is replaced by the min expression.
///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
///     replaced by the max expression.
/// All known values are iteratively replaced.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to have
/// positive values (positive orthant assumptions).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
                            GetMinMaxExprFn getMinMaxExpr) {
  AffineMapAndOperands res{affineMinOp.getAffineMap(),
                           SmallVector<Value>(affineMinOp.getDimOperands()),
                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
                       getMinMaxExpr);
  return res;
}

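/// Canonicalize an affine.min op whose operands have known ranges, as reported
/// by `getMinMaxFn`. A sketch with a hypothetical range: for
///   %m = affine.min affine_map<(d0) -> (d0, 4)>(%i)
/// with `%i` known to lie in [4, 8], substitution proves that `4` is the
/// minimum, and `%m` is replaced by `constant 4 : index`.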
LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
  AffineMap map = affineMapAndOperands.map;

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    auto isNonPositive = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that all results are statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNonPositive))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
/// with pad_val) and GenericOp (to copy contents).
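///
/// A sketch with hypothetical static shapes:
///   %0 = linalg.pad_tensor %src low[1, 2] high[2, 3] { ... }
///       : tensor<5x11xf32> to tensor<8x16xf32>
/// becomes, roughly:
///   %i = linalg.init_tensor [8, 16] : tensor<8x16xf32>
///   %f = linalg.fill(%cst, %i) : f32, tensor<8x16xf32> -> tensor<8x16xf32>
///   %0 = linalg.generic ... // copies %src into %f at offset (1, 2)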
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy original contents into new tensor
  // Uses linalg.generic, but could be done with tensor.insert_slice
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` with the constant padding value using a FillOp if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

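/// Generalize a PadTensorOp by allocating an InitTensorOp of the result shape,
/// filling it with the padding value (or a tensor.generate region when the
/// value is not a constant), and inserting the source with a
/// tensor.insert_slice at the low-padding offsets. This is a sketch of the
/// overall flow; `optimizeCopyFn`, when provided, may take over the copy of
/// the source instead.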
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
/// it must be of type Integer.
static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
  if (auto val = ofr.dyn_cast<Value>())
    return val;
  auto intVal = getConstantIntValue(ofr);
  assert(intVal && "expected Value or IntegerAttr");
  return builder.create<ConstantIndexOp>(loc, *intVal);
}

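/// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)),
/// recomputing offsets, sizes, and padding amounts. A sketch with hypothetical
/// sizes: slicing 4 elements starting at offset 1 out of a dimension with low
/// padding 2 reads one padding element and three source elements, so the
/// rewrite produces a PadTensorOp with low padding 1 around an ExtractSliceOp
/// of the first three source elements.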
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = sliceOp.getLoc();
  AffineExpr dim0, dim1;
  bindDims(rewriter.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
                                                ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
                                                ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
  auto min = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);

  // Helper function for filling static/dynamic low/high padding indices
  // vectors of PadTensorOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
    auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
    auto srcSize =
        rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading no
    // data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position
    // `offset + length`. Therefore, the corresponding position within the
    // source tensor is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end position
    // of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
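    //
    // Worked example with hypothetical values: low = 2, offset = 1,
    // length = 4, srcSize = 5. Then newLow = max(0, 2 - 1) = 1,
    // newOffset = min(max(1 - 2, 0), 5) = 0,
    // endLoc = min(max(1 - 2 + 4, 0), 5) = 3, and newLength = 3 - 0 = 3:
    // the slice reads one element of low padding and three source elements.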
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check = rewriter.create<CmpIOp>(
          loc, CmpIPredicate::eq, newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride supported.
    newStrides.push_back(rewriter.getIndexAttr(1));
  }

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    auto castOp = rewriter.create<tensor::CastOp>(loc, sliceOp.getType(), val);
    return castOp;
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // The shape of the GenerateOp is the same as the existing SliceOp.
    RankedTensorType type = sliceOp.getType();
    SmallVector<Value> dynDims;
    for (unsigned i = 0; i < type.getRank(); ++i) {
      if (type.isDynamicDim(i))
        dynDims.push_back(asValue(rewriter, loc, sliceOp.getMixedSizes()[i]));
    }

    // Create GenerateOp.
    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);

    // Copy region to new op.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
    // Rewrite linalg::YieldOp to tensor::YieldOp.
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto yieldOp = dyn_cast<linalg::YieldOp>(
          generateOp.getRegion().front().getTerminator());
      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
      assert(yieldOp.values().size() == 1);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<tensor::YieldOp>(
          yieldOp, yieldOp.values()[0]);
    }

    return castResult(generateOp);
  };

  // Emit a SliceOp and a PadTensorOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadTensorOfSubTensor = [&]() {
    // Create pad_tensor(subtensor(x)).
    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, padOp.source(), newOffsets, newLengths, newStrides);
    auto newPadTensorOp = rewriter.create<PadTensorOp>(
        loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);

    // Copy region to new PadTensorOp.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);

    // Cast result and return.
    return castResult(newPadTensorOp);
  };

  // Rewrite subtensor(pad_tensor(x)) into a GenerateOp when it is statically
  // known that the original data source x is not used.
  if (hasZeroLen) {
    rewriter.replaceOp(sliceOp, createGenerateOp());
    return success();
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (dynHasZeroLenCond) {
    auto result = rewriter.create<scf::IfOp>(
        loc, sliceOp.getType(), dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp());
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadTensorOfSubTensor());
        });
    rewriter.replaceOp(sliceOp, result.getResult(0));
    return success();
  }

  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, createPadTensorOfSubTensor());
  return success();
}