1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <type_traits>
32 
33 #define DEBUG_TYPE "linalg-transforms"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
// Debug-stream helper: expands to llvm::dbgs() with a "[" DEBUG_TYPE "]: "
// prefix already streamed into it, so callers can keep appending with `<<`.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
39 
40 //===----------------------------------------------------------------------===//
41 // Transformations exposed as rewrite patterns.
42 //===----------------------------------------------------------------------===//
// Name of the string attribute that the Linalg rewriting transformations in
// this file read (to filter ops) and write (to mark transformed ops).
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
46 
/// Build a filter from the list of accepted markers `matchDisjunction` and an
/// optional `replacement` marker. This overload does not add any filter
/// function.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}
51 
/// Build a filter from a filter function `f`, the list of accepted markers
/// `matchDisjunction` and an optional `replacement` marker.
/// `f` is only recorded when it is non-empty, so a null function results in a
/// filter without any filter function.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  // Skip empty callables: they would fail when invoked in checkAndNotify.
  if (f)
    filters.push_back(f);
}
61 
62 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
63     PatternRewriter &rewriter, Operation *op) const {
64   if (llvm::any_of(filters,
65                    [&](const FilterFunction &f) { return failed(f(op)); }))
66     return failure();
67 
68   auto attr = op->template getAttrOfType<StringAttr>(
69       LinalgTransforms::kLinalgTransformMarker);
70 
71   if (!attr) {
72     // 1. Has no filter case and matchDisjunction is empty.
73     if (matchDisjunction.empty())
74       return success();
75 
76     // 2. Has no filter but was expecting a filter.
77     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
78       diag << " does not have any filter from list: ";
79       interleaveComma(matchDisjunction, diag);
80     });
81   }
82 
83   // 4. Match explicit filter.
84   for (auto filter : matchDisjunction)
85     if (attr.getValue() == filter)
86       return success();
87 
88   // 5. Fail to match.
89   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
90     diag << " does not have any filter from list: ";
91     interleaveComma(matchDisjunction, diag);
92   });
93 }
94 
95 void mlir::linalg::LinalgTransformationFilter::
96     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
97                                       Operation *op) const {
98   if (replacement.hasValue())
99     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
100                 rewriter.getStringAttr(replacement.getValue().strref()));
101   else
102     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
103                                    rewriter.getContext()));
104 }
105 
106 LinalgTilingOptions &
107 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
108   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
109   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
110     OpBuilder::InsertionGuard guard(b);
111     b.setInsertionPointToStart(
112         &op->getParentOfType<FuncOp>().getBody().front());
113     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
114       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
115       return v;
116     }));
117   };
118   return *this;
119 }
120 
121 /// Try to compute a static bounding box for `operand`
122 /// Return success if either:
123 ///   1. The operand is already statically shaped, `result` is left unchanged.
124 ///   2. The operand is (partially) dynamic, `result` is the result of a freshly
125 ///      created PadTensorOp.
126 /// Return failure if the operand cannot be padded to a static shape.
127 static LogicalResult padOperandToSmallestStaticBoundingBox(
128     PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
129     const PaddingValueComputationFunction &paddingFunc, Value &result) {
130   // Already static shape, no need to pad.
131   if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
132     return success();
133   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
134   // Not a slice op, cannot construct a static bounding box.
135   if (!sliceOp)
136     return failure();
137   SmallVector<int64_t> staticSizes;
138   staticSizes.reserve(opToPad.getRank(opOperand));
139   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
140   for (auto size : shapedOp.getMixedSizes()) {
141     auto indexAttr = size.is<Attribute>()
142                          ? size.get<Attribute>().dyn_cast<IntegerAttr>()
143                          : linalg::getSmallestBoundingIndex(size.get<Value>());
144     // SmallestBoundingIndex must exist for all sizes.
145     // For now return an error if we can't find it.
146     if (!indexAttr)
147       return rewriter.notifyMatchFailure(
148           opToPad, "No constant bounding box can be found for padding");
149     staticSizes.push_back(indexAttr.getInt());
150   }
151   Value pad = paddingFunc(rewriter, *opOperand);
152   auto staticTensorType = RankedTensorType::get(
153       staticSizes, getElementTypeOrSelf(opOperand->get()));
154   result = linalg::PadTensorOp::createPadHighOp(
155       staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
156   return success();
157 }
158 
/// Rewrite `opToPad` (which must have tensor semantics) as a clone operating
/// on statically shaped, padded operands.
/// - When `opToPad` has no dynamic shape, nothing is rewritten: success is
///   returned and `paddedOp` is left untouched.
/// - Otherwise every input/output operand is padded with
///   padOperandToSmallestStaticBoundingBox; failure is returned as soon as one
///   operand cannot be padded. On success `paddedOp` is the clone, and the
///   uses of the results of `opToPad` are replaced by slices of the results of
///   `paddedOp` that have the original result sizes.
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  // Restore the insertion point of the rewriter when leaving this function.
  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    // A null `paddedOperand` means the operand was already static: reuse it.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  // The result types are the types of the trailing (output) operands.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  // DimOps created below; they must keep using the original results.
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    // The slice starts at offset 0 in every dimension.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    // The slice sizes are the (dynamic) sizes of the original result.
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    // The slice has unit strides.
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });
  return success();
}
219 
/// Linalg base tiling pattern.
/// This overload roots the pattern on operations named `opName` and stores
/// the tiling `options` and the marker `filter` for use when rewriting.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}
226 
/// Overload that is not tied to an operation name: the pattern is built with
/// MatchAnyOpTypeTag, and non-Linalg ops are rejected in matchAndRewriteBase.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
232 
233 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
234     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
235   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
236   if (!linalgOp)
237     return failure();
238   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
239     return failure();
240 
241   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
242 
243   if (!res)
244     return failure();
245 
246   // Setup RAII guard to return properly.
247   LinalgOp tiledOp = res->op;
248   auto guard = llvm::make_scope_exit([&]() {
249     // Return relevant information to derived pattern.
250     result = *res;
251     // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
252     filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
253     if (tiledOp != res->op)
254       filter.replaceLinalgTransformationFilter(rewriter, res->op);
255   });
256 
257   // Consider padding on the fly only if the op has tensor semantics.
258   if (!options.paddingValueComputationFunction ||
259       !linalgOp.hasTensorSemantics())
260     return success();
261 
262   // Try to pad on the fly by rewriting res->op as a padded op. If successful,
263   // `res.op` is rewritten in static form with padded operands.
264   LinalgOp paddedOp;
265   if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
266                                   options.paddingValueComputationFunction,
267                                   paddedOp))) {
268     res->op = paddedOp;
269     // Do not perform replacement of `linalgOp`, let the derived patterns
270     // do this as they see fit, from the resulting TiledLinalgOp.
271     return success();
272   }
273   // Set so RAII guard does not propagate TiledLinalgOp to `result`.
274   return failure();
275 }
276 
277 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
278   if (tiledOp.loops.empty())
279     return tiledOp.op.getOperation()->getResults();
280   return tiledOp.loops.front()->getResults();
281 }
282 
283 static ValueRange
284 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
285   if (tiledAndFusedOp.fusedLoops.empty())
286     return tiledAndFusedOp.op.getOperation()->getResults();
287   return tiledAndFusedOp.fusedLoops.front()->getResults();
288 }
289 
/// Build a tile-and-fuse pattern rooted on operations named `opName`.
/// `filter` selects the ops the pattern applies to and marks the tiled op;
/// `fusedOpMarker` is used to mark the fused producers and `originalOpMarker`
/// to mark the original (unfused) ops, see matchAndRewrite.
/// Note that `dependenceGraph` is taken by reference.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
300 
/// Tile `op` and fuse into it the Linalg producers selected by
/// `fusionOptions.indicesToFuse`, then tile the loops that were not fused.
/// Fails when `op` is not a LinalgOp, has index semantics, is rejected by
/// `filter`, or when one of the tiling steps fails.
/// On success the uses of `op` are redirected to the tiled-and-fused result
/// and the markers are updated: `filter` on the tiled op, `fusedOpMarker` on
/// the fused producers, `originalOpMarker` on the original producers and on
/// `op`.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the ops taking part in the fusion: `linalgOp` itself plus the
  // Linalg ops it depends on through one of the operands selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Order the ops to fuse as they appear in the block of `op`: walk the block
  // up to (excluding) `op`, then append `op` last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize the tile sizes once and reuse these values for both the
  // tile-and-fuse step and the tiling of the unfused loops below.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops: loops that were already tiled and fused get a
  // zero tile size, the other ones keep the requested tile size.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  // Tile sizes in excess of the number of loops of the op are dropped first.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // A tile size that is not defined by a ConstantIndexOp counts as non-zero.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  // Redirect the uses of the original op to the tiled-and-fused result. The
  // original op is not erased here; it only gets a new marker below.
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update the markers of the tiled op, the fused producers, the original
  // producers and the original op.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
386 
/// Linalg generic interchange pattern.
/// Stores a copy of `interchangeVector`, which is later handed to
/// interchangeGenericOp, and the marker `filter` used to select and re-mark
/// the ops.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
393 
394 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
395     GenericOp genericOp, PatternRewriter &rewriter) const {
396   if (failed(filter.checkAndNotify(rewriter, genericOp)))
397     return failure();
398   if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
399     return failure();
400 
401   // TODO: figure out how this interplays with named ops. In particular this
402   // should break the named op property.
403   rewriter.updateRootInPlace(genericOp, [&]() {
404     interchangeGenericOp(rewriter, genericOp, interchangeVector);
405     // New filter if specified.
406     filter.replaceLinalgTransformationFilter(rewriter, genericOp);
407   });
408   return success();
409 }
410 
/// Build a promotion pattern rooted on operations named `opName`, storing the
/// promotion `options` and the marker `filter` used by matchAndRewrite.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
416 
417 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
418     Operation *op, PatternRewriter &rewriter) const {
419   if (failed(filter.checkAndNotify(rewriter, op)))
420     return failure();
421   if (failed(promoteSubviewsPrecondition(op, options)))
422     return failure();
423 
424   // TODO: We cannot use root update here. This pattern is creating other ops,
425   // so if the promotion fails, those need to be cleaned up, which doesnt seem
426   // to be happening here. So to fail properly, we should be cloning the op and
427   // deleting the previous op. This needs more investigation.
428   rewriter.startRootUpdate(op);
429   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
430   if (!promotedOp) {
431     rewriter.cancelRootUpdate(op);
432     return op->emitError("subview promotion failed");
433   }
434   rewriter.finalizeRootUpdate(op);
435   filter.replaceLinalgTransformationFilter(rewriter, op);
436   return success();
437 }
438 
/// Build a vectorization pattern that is not tied to an operation name (it
/// is constructed with MatchAnyOpTypeTag); non-Linalg ops are rejected in
/// matchAndRewrite.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
443 
/// Build a vectorization pattern rooted on operations named `opName`.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
448 
449 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
450     Operation *op, PatternRewriter &rewriter) const {
451   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
452   if (!linalgOp)
453     return failure();
454   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
455     return failure();
456   SmallVector<Value> newResults;
457   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
458     return failure();
459   if (!newResults.empty())
460     rewriter.replaceOp(op, newResults);
461   else
462     rewriter.eraseOp(op);
463   return success();
464 }
465 
466 LogicalResult mlir::linalg::applyStagedPatterns(
467     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
468     const FrozenRewritePatternSet &stage2Patterns,
469     function_ref<LogicalResult(Operation *)> stage3Lambda) {
470   unsigned iteration = 0;
471   (void)iteration;
472   for (const auto &patterns : stage1Patterns) {
473     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
474                       << *op);
475     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
476       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
477       return failure();
478     }
479     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
480                       << *op);
481     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
482       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
483       return failure();
484     }
485     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
486                       << *op);
487     if (stage3Lambda) {
488       if (failed(stage3Lambda(op)))
489         return failure();
490       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
491                         << *op);
492     }
493   }
494   return success();
495 }
496 
497 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
498   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
499 }
500 
501 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
502 /// with pad_val) and GenericOp (to copy contents).
503 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
504     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
505 
506   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
507   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
508 
509   // Bail on non-static shapes.
510   if (!inputShapedType.hasStaticShape())
511     return failure();
512   if (!resultShapedType.hasStaticShape())
513     return failure();
514 
515   // Only support padding with a constant for now, i.e. either:
516   //   1. A BBarg from a different block.
517   //   2. A value defined outside of the current block.
518   Block &block = padOp.region().front();
519   auto yieldOp = cast<YieldOp>(block.getTerminator());
520   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
521   Value padValue = yieldOp.values().front();
522   Operation *definingOp = padValue.getDefiningOp();
523   if (definingOp && definingOp->getBlock() == &block)
524     return failure();
525   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
526     return failure();
527 
528   // Create tensor with the padded shape
529   Location loc = padOp.getLoc();
530   SmallVector<Value> indices(resultShapedType.getRank(),
531                              rewriter.create<ConstantIndexOp>(loc, 0));
532   Value initTensor = rewriter.create<InitTensorOp>(
533       loc, resultShapedType.getShape(), resultShapedType.getElementType());
534 
535   // Initialize tensor with the pad value
536   Value tmpTensor =
537       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
538 
539   // Copy original contents into new tensor
540   // Uses linalg.generic, but could be done with tensor.insert_slice
541   SmallVector<AffineExpr, 4> outputExprs;
542   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
543     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
544                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
545   }
546 
547   SmallVector<AffineMap, 2> transferMaps = {
548       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
549       AffineMap::get(resultShapedType.getRank(),
550                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
551 
552   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
553       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
554       getNParallelLoopsAttrs(resultShapedType.getRank()),
555       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
556         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
557       });
558 
559   return success();
560 }
561 
562 /// Filling `dest` using FillOp constant padding value if possible.
563 /// Otherwise, generate a tensor::GenerateOp.
564 Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
565     PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
566     const SmallVector<Value> &dynSizes) const {
567   auto padValue = padOp.getConstantPaddingValue();
568   if (padValue)
569     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
570 
571   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
572   auto generateOp = rewriter.create<tensor::GenerateOp>(
573       padOp.getLoc(), padOp.getResultType(), dynSizes);
574   // Copy region to new op.
575   BlockAndValueMapping bvm;
576   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
577   // Rewrite linalg::YieldOp to tensor::YieldOp.
578   OpBuilder::InsertionGuard guard(rewriter);
579   auto yieldOp =
580       dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
581   assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
582   assert(yieldOp.values().size() == 1);
583   rewriter.setInsertionPoint(yieldOp);
584   rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
585   return generateOp;
586 }
587 
/// Rewrite `padOp` as: an InitTensorOp of the padded shape, filled with the
/// padding (see createFillOrGenerateOp), into which the source of `padOp` is
/// copied.
/// The copy is first delegated to `optimizeCopyFn` when it is set; if it is
/// not set or fails, the source is copied with a tensor::InsertSliceOp placed
/// at the low padding offsets. This function always returns success.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  // Attributes are materialized as ConstantIndexOps.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  // `staticSizes` gets one entry per dimension (the dim size of the result
  // type), `dynSizes` one value per dynamic dimension.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      // Dynamic result size = source size + low padding + high padding.
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  // On success the callback is relied upon to have taken care of `padOp`.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  // Dynamic sizes are read with a DimOp, static ones become index attributes.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source is inserted into the filled tensor at the low padding offsets.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
650 
651 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
652     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
653   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
654   if (!padOp)
655     return failure();
656   // Only unit stride supported.
657   if (!sliceOp.hasUnitStride())
658     return failure();
659 
660   Operation *tiledPadOp = padOp.getTiledImplementation(
661       rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
662       sliceOp.getMixedSizes());
663   // All shapes are static and the data source is actually used. Rewrite into
664   // pad_tensor(subtensor(x)).
665   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
666   return success();
667 }
668