1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <type_traits>
32 
33 #define DEBUG_TYPE "linalg-transforms"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
39 
40 //===----------------------------------------------------------------------===//
41 // Transformations exposed as rewrite patterns.
42 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// Patterns in this file read this StringAttr via checkAndNotify() and rewrite
// it via replaceLinalgTransformationFilter() to stage rewrites.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
46 
/// Build a filter that matches ops whose transform marker equals any entry of
/// `matchDisjunction`, and that rewrites the marker to `replacement` (or
/// removes it when `replacement` is None) after a successful rewrite.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}
51 
52 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
53     FilterFunction f, ArrayRef<Identifier> matchDisjunction,
54     Optional<Identifier> replacement)
55     : filters(),
56       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
57       replacement(replacement) {
58   if (f)
59     filters.push_back(f);
60 }
61 
/// Return success if `op` passes every registered filter function and carries
/// an acceptable transform marker; otherwise notify the rewriter with the
/// reason the match failed.
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  // Reject immediately if any user-provided predicate fails on `op`.
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter: succeed on the first marker that matches.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match any entry of the disjunction.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
94 
95 void mlir::linalg::LinalgTransformationFilter::
96     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
97                                       Operation *op) const {
98   if (replacement.hasValue())
99     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
100                 rewriter.getStringAttr(replacement.getValue().strref()));
101   else
102     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
103                                    rewriter.getContext()));
104 }
105 
106 LinalgTilingOptions &
107 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
108   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
109   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
110     OpBuilder::InsertionGuard guard(b);
111     b.setInsertionPointToStart(
112         &op->getParentOfType<FuncOp>().getBody().front());
113     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
114       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
115       return v;
116     }));
117   };
118   return *this;
119 }
120 
/// Try to compute a static bounding box for `operand`
/// Return success if either:
///   1. The operand is already statically shaped, `result` is left unchanged.
///   2. The operand is (partially) dynamic, `result` is the result of a freshly
///      created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Already static shape, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  // The bounding box is derived from the slice that produced the operand.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  // For each dimension: take the size directly when it is a constant
  // attribute; otherwise try to derive the smallest constant upper bound of
  // the dynamic size value.
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  // Pad the operand up to the static bounding box using the client-provided
  // padding value.
  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}
158 
/// Rewrite `opToPad` so that all its tensor operands have static shapes,
/// padding dynamic dimensions up to constant bounding boxes. On success,
/// `paddedOp` is the statically-shaped clone and the results of `opToPad` are
/// replaced with slices of the padded results sized like the originals.
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    // `paddedOperand` stays null when the operand was already static; keep
    // the original value in that case.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the dynamic result shapes of the original op so the final slices
  // can be sized to match its results exactly.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    // Zero offsets and unit strides; sizes come from the reified dims of the
    // corresponding original result.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}
219 
220 /// Linalg base tiling pattern.
221 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
222     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
223     LinalgTransformationFilter filter, PatternBenefit benefit)
224     : RewritePattern(opName, benefit, context), filter(filter),
225       options(options) {}
226 
227 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
228     MLIRContext *context, LinalgTilingOptions options,
229     LinalgTransformationFilter filter, PatternBenefit benefit)
230     : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
231       options(options) {}
232 
/// Match a LinalgOp, tile it according to `options`, and optionally pad the
/// tiled op to static shapes. On return, `result` carries the tiling
/// artifacts for derived patterns to consume; the original op is not replaced
/// here — derived patterns do that.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Setup RAII guard to return properly.
  LinalgOp paddedOp;
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to derived pattern.
    result = *res;
    // Update filter on whichever op survives (padded clone if padding ran,
    // otherwise the tiled op).
    if (paddedOp)
      filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    else
      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    res->op = paddedOp;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  // NOTE(review): the scope_exit guard is not released on this path, so it
  // still runs and assigns `result` — confirm whether that is intended.
  return failure();
}
277 
278 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
279   if (tiledOp.loops.empty())
280     return tiledOp.op.getOperation()->getResults();
281   return tiledOp.loops.front()->getResults();
282 }
283 
284 static ValueRange
285 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
286   if (tiledAndFusedOp.fusedLoops.empty())
287     return tiledAndFusedOp.op.getOperation()->getResults();
288   return tiledAndFusedOp.fusedLoops.front()->getResults();
289 }
290 
291 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
292     StringRef opName, MLIRContext *context,
293     const LinalgDependenceGraph &dependenceGraph,
294     LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
295     LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
296     LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
297     : RewritePattern(opName, benefit, context, {}),
298       dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
299       fusionOptions(fusionOptions), filter(filter),
300       fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
301 
/// Match a LinalgOp, collect the producers selected by `fusionOptions`, tile
/// and fuse them together, then tile any loops left unfused. Transformation
/// markers are updated on the tiled root, the fused producers, and the
/// original producer ops that remain in the IR.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the LinalgOp producers feeding the operands selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the selected producers in block order, ending with the root op.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize the tile sizes once so both tiling phases use the same
  // values.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Loops that were already fused get a zero tile
  // size so the second tiling leaves them untouched.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Drop excess tile sizes, then tile only if some remaining size is non-zero.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  // Redirect all uses of the original op to the tiled-and-fused results.
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update markers: tiled root, each fused producer, and the original
  // producers (which remain in the IR) each receive their respective marker.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
387 
/// Linalg generic interchange pattern: permutes the iteration dimensions of a
/// GenericOp according to `interchangeVector`.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
394 
/// Apply the configured interchange to `genericOp` in place after checking the
/// filter and the interchange precondition.
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
411 
412 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
413     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
414     LinalgTransformationFilter filter, PatternBenefit benefit)
415     : RewritePattern(opName, benefit, context, {}), filter(filter),
416       options(options) {}
417 
/// Check the filter and promotion preconditions, then promote the subviews of
/// the matched op. A failed promotion emits an error (not a silent match
/// failure) because ops may already have been created — see the TODO below.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesnt seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
439 
440 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
441     MLIRContext *context, LinalgTransformationFilter filter,
442     PatternBenefit benefit)
443     : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
444 
445 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
446     StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
447     PatternBenefit benefit)
448     : RewritePattern(opName, benefit, context, {}), filter(filter) {}
449 
450 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
451     Operation *op, PatternRewriter &rewriter) const {
452   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
453   if (!linalgOp)
454     return failure();
455   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
456     return failure();
457   SmallVector<Value> newResults;
458   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
459     return failure();
460   if (!newResults.empty())
461     rewriter.replaceOp(op, newResults);
462   else
463     rewriter.eraseOp(op);
464   return success();
465 }
466 
/// Apply each set in `stage1Patterns` in turn, running `stage2Patterns`
/// (typically canonicalizations) after each set and the optional
/// `stage3Lambda` (a pass-like cleanup) after that. Fails as soon as any
/// stage fails to converge.
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  // `iteration` only feeds the debug output; the (void) cast silences the
  // unused-variable warning in release builds where LLVM_DEBUG compiles away.
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
497 
498 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
499   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
500 }
501 
502 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
503 /// with pad_val) and GenericOp (to copy contents).
504 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
505     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
506 
507   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
508   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
509 
510   // Bail on non-static shapes.
511   if (!inputShapedType.hasStaticShape())
512     return failure();
513   if (!resultShapedType.hasStaticShape())
514     return failure();
515 
516   // Only support padding with a constant for now, i.e. either:
517   //   1. A BBarg from a different block.
518   //   2. A value defined outside of the current block.
519   Block &block = padOp.region().front();
520   auto yieldOp = cast<YieldOp>(block.getTerminator());
521   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
522   Value padValue = yieldOp.values().front();
523   Operation *definingOp = padValue.getDefiningOp();
524   if (definingOp && definingOp->getBlock() == &block)
525     return failure();
526   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
527     return failure();
528 
529   // Create tensor with the padded shape
530   Location loc = padOp.getLoc();
531   SmallVector<Value> indices(resultShapedType.getRank(),
532                              rewriter.create<ConstantIndexOp>(loc, 0));
533   Value initTensor = rewriter.create<InitTensorOp>(
534       loc, resultShapedType.getShape(), resultShapedType.getElementType());
535 
536   // Initialize tensor with the pad value
537   Value tmpTensor =
538       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
539 
540   // Copy original contents into new tensor
541   // Uses linalg.generic, but could be done with tensor.insert_slice
542   SmallVector<AffineExpr, 4> outputExprs;
543   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
544     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
545                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
546   }
547 
548   SmallVector<AffineMap, 2> transferMaps = {
549       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
550       AffineMap::get(resultShapedType.getRank(),
551                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
552 
553   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
554       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
555       getNParallelLoopsAttrs(resultShapedType.getRank()),
556       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
557         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
558       });
559 
560   return success();
561 }
562 
/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  // Fast path: the pad region yields a constant; lower to a simple fill.
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp, since GenerateOp expects the
  // tensor dialect terminator.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
588 
/// Lower a PadTensorOp to InitTensorOp + fill (or GenerateOp) + a copy of the
/// source. The copy is delegated to `optimizeCopyFn` when provided, falling
/// back to a tensor::InsertSliceOp otherwise.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    // Attribute case: materialize the constant as an index op.
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source is inserted at an offset equal to the low padding.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
651 
/// Swap an ExtractSliceOp of a PadTensorOp into a tiled PadTensorOp, so the
/// pad is computed only on the extracted region.
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  // Delegate to the pad op's tiling implementation with the slice's offsets
  // and sizes.
  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  // NOTE(review): the comment above appears to describe only one of the cases
  // getTiledImplementation can produce — confirm against its implementation.
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
669