1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>
35 
36 #define DEBUG_TYPE "linalg-transforms"
37 
38 using namespace mlir;
39 using namespace mlir::linalg;
40 
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
42 
43 //===----------------------------------------------------------------------===//
44 // Transformations exposed as rewrite patterns.
45 //===----------------------------------------------------------------------===//
46 // Marker used as attribute name in generated Linalg rewriting transformations.
// This StringAttr name is read by LinalgTransformationFilter::checkAndNotify
// and written/removed by replaceLinalgTransformationFilter below.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
49 
/// Construct a filter that matches ops whose transform marker attribute equals
/// any entry of `matchDisjunction`. `replacement`, when set, is the marker
/// value installed by replaceLinalgTransformationFilter after a rewrite.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}
54 
55 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
56     FilterFunction f, ArrayRef<Identifier> matchDisjunction,
57     Optional<Identifier> replacement)
58     : filters(),
59       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
60       replacement(replacement) {
61   if (f)
62     filters.push_back(f);
63 }
64 
65 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
66     PatternRewriter &rewriter, Operation *op) const {
67   if (llvm::any_of(filters,
68                    [&](const FilterFunction &f) { return failed(f(op)); }))
69     return failure();
70 
71   auto attr = op->template getAttrOfType<StringAttr>(
72       LinalgTransforms::kLinalgTransformMarker);
73 
74   if (!attr) {
75     // 1. Has no filter case and matchDisjunction is empty.
76     if (matchDisjunction.empty())
77       return success();
78 
79     // 2. Has no filter but was expecting a filter.
80     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
81       diag << " does not have any filter from list: ";
82       interleaveComma(matchDisjunction, diag);
83     });
84   }
85 
86   // 4. Match explicit filter.
87   for (auto filter : matchDisjunction)
88     if (attr.getValue() == filter)
89       return success();
90 
91   // 5. Fail to match.
92   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
93     diag << " does not have any filter from list: ";
94     interleaveComma(matchDisjunction, diag);
95   });
96 }
97 
98 void mlir::linalg::LinalgTransformationFilter::
99     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
100                                       Operation *op) const {
101   if (replacement.hasValue())
102     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
103                 rewriter.getStringAttr(replacement.getValue().strref()));
104   else
105     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
106                                    rewriter.getContext()));
107 }
108 
109 LinalgTilingOptions &
110 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
111   assert(!tileSizeComputationFunction && "tile sizes already set");
112   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
113   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
114     OpBuilder::InsertionGuard guard(b);
115     b.setInsertionPointToStart(
116         &op->getParentOfType<FuncOp>().getBody().front());
117     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
118       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
119       return v;
120     }));
121   };
122   return *this;
123 }
124 
125 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
126   assert(!tileSizeComputationFunction && "tile sizes already set");
127   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
128     SmallVector<Value, 4> tileSizes;
129     auto linalgOp = dyn_cast<LinalgOp>(op);
130     if (!linalgOp)
131       return tileSizes;
132     Location loc = linalgOp.getLoc();
133     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
134     AffineMap map = linalgOp.getShapesToLoopsMap();
135     if (!map)
136       return tileSizes;
137     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
138     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
139     // size 0).
140     for (Value shapeSize : shapeSizes)
141       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
142                               ? b.create<arith::ConstantIndexOp>(loc, 0)
143                               : b.create<arith::ConstantIndexOp>(loc, 1));
144     return tileSizes;
145   };
146   return *this;
147 }
148 
/// Helper function that tries to pad `opOperand`. Exit early and return success
/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
/// pad the operand even if it already has a static shape. Set `result` to the
/// result of the created PadTensorOp or return failure if the operand cannot be
/// padded to a static shape. When the function succeeds without padding,
/// `result` is left null and the caller uses the original operand instead.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Can't pad scalars: a rank-0 operand is trivially "padded" already.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  // Derive the static bounding box from the slice sizes: constant sizes are
  // used directly; dynamic ones are bounded via getSmallestBoundingIndex.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  // Pad up to the bounding box; `nofold` keeps the pad op alive even when the
  // operand shape is already static.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/true, opToPad->getLoc(), rewriter);
  return success();
}
189 
/// Pad every input/output operand of `opToPad` to its smallest static bounding
/// box (see padOperandToSmallestStaticBoundingBox), clone the op onto the
/// padded values, and replace `opToPad` with slices of the padded results that
/// recover the original result shapes. On success `paddedOp` is set to the
/// statically shaped clone.
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    // `paddedOperand` stays null when the helper decided no padding is needed.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the result shapes of the original op so the padded results can be
  // sliced back down to those extents below.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    // Slice from offset 0 with unit strides and the reified result sizes.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}
247 
/// Linalg base tiling pattern.
/// Construct a tiling pattern restricted to operations named `opName`.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

/// Construct a tiling pattern that matches any operation; candidates are
/// narrowed by the LinalgOp cast and `filter` in matchAndRewriteBase.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
260 
261 /// Try to peel a loop `op` and return the new result.
262 // TODO: Add support for scf.parallel and affine.for loops.
263 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
264   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
265       .Case<scf::ForOp>([&](scf::ForOp forOp) {
266         scf::ForOp partialIteration;
267         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
268                                                       partialIteration)))
269           return partialIteration->getResults();
270         assert(!partialIteration && "expected that loop was not peeled");
271         return forOp->getResults();
272       })
273       .Default([&](Operation *op) { return op->getResults(); });
274 }
275 
276 /// Try to peel a TiledLoopOp and return the new result.
277 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
278                                       TiledLoopOp tiledLoop, int64_t idx) {
279   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
280          "requested peeling of non-existing loop");
281   TiledLoopOp result;
282   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
283     return result->getResults();
284   assert(!result && "expected that loop was not peeled");
285   return tiledLoop->getResults();
286 }
287 
/// Peel loops after tiling: for every index in `options.peeledLoops`, peel the
/// corresponding generated loop and, when that loop produced the values
/// recorded in `res.tensorResults`, re-point those at the peeled results.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      // With TiledLoops every entry of `res.loops` is the same op, and `loop`
      // selects which dimension of that single TiledLoopOp to peel.
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling. Only forward the
    // peeled results when `tensorResults` was exactly this loop's results.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
315 
/// Tile `op` according to `options`, then optionally peel the requested loops
/// and pad the tiled op. On success `result` describes the tiled op and its
/// surrounding loops; replacing the original op is left to derived patterns.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    // Tag the padded op too, so patterns do not fire on it again.
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  return failure();
}
357 
358 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
359   if (tiledOp.loops.empty())
360     return tiledOp.op.getOperation()->getResults();
361   return tiledOp.loops.front()->getResults();
362 }
363 
364 static ValueRange
365 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
366   if (tiledAndFusedOp.fusedLoops.empty())
367     return tiledAndFusedOp.op.getOperation()->getResults();
368   return tiledAndFusedOp.fusedLoops.front()->getResults();
369 }
370 
/// Construct a tile-and-fuse pattern restricted to ops named `opName`. The
/// three filters respectively re-tag the tiled root op (`filter`), the fused
/// producers (`fusedOpMarker`), and the original ops left behind
/// (`originalOpMarker`) — see matchAndRewrite below.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
381 
/// Tile `op` and fuse the producers selected by `fusionOptions.indicesToFuse`
/// into the tile loops, then tile any remaining (unfused) loops, and finally
/// re-tag all produced ops with the configured markers.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect `linalgOp` plus every LinalgOp producer feeding one of the
  // operands listed in `fusionOptions.indicesToFuse`.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the selected producers that precede `op` in its block, in program
  // order; `linalgOp` itself goes last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Run tile-and-fuse with a per-instance copy of the options carrying the
  // concrete tile size values computed for this op.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops;
  // loops that participated in fusion get tile size 0 (i.e. are not re-tiled).
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Re-tag everything so patterns do not fire again on the produced ops.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
467 
/// Linalg generic interchange pattern.
/// The interchange vector is copied into owned storage because the pattern may
/// outlive the caller's buffer.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
474 
475 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
476     GenericOp genericOp, PatternRewriter &rewriter) const {
477   if (failed(filter.checkAndNotify(rewriter, genericOp)))
478     return failure();
479   if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
480     return failure();
481 
482   // TODO: figure out how this interplays with named ops. In particular this
483   // should break the named op property.
484   rewriter.updateRootInPlace(genericOp, [&]() {
485     interchangeGenericOp(rewriter, genericOp, interchangeVector);
486     // New filter if specified.
487     filter.replaceLinalgTransformationFilter(rewriter, genericOp);
488   });
489   return success();
490 }
491 
/// Linalg generalization pattern.
/// Match any operation; suitability is decided by `filter` and the
/// precondition check in matchAndRewrite.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

/// Variant restricted to operations named `opName`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
502 
503 LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
504     Operation *op, PatternRewriter &rewriter) const {
505   if (failed(filter.checkAndNotify(rewriter, op)))
506     return failure();
507   if (failed(generalizeNamedOpPrecondition(op)))
508     return failure();
509 
510   GenericOp genericOp = generalizeNamedOp(rewriter, op);
511   rewriter.replaceOp(op, genericOp.getResults());
512   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
513   return success();
514 }
515 
/// Construct a promotion pattern matching any operation; candidates are
/// narrowed by `filter` and promoteSubviewsPrecondition in matchAndRewrite.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Variant restricted to operations named `opName`.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
527 
/// Promote the subviews of `op` according to `options`. The op is mutated in
/// place, so the rewrite is bracketed by start/finalize(or cancel)RootUpdate.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    // Roll back the in-place update notification before reporting failure.
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
549 
/// Construct a vectorization pattern matching any operation; candidates are
/// narrowed by the LinalgOp cast and `filter` in matchAndRewrite.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

/// Variant restricted to operations named `opName`.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
559 
560 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
561     Operation *op, PatternRewriter &rewriter) const {
562   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
563   if (!linalgOp)
564     return failure();
565   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
566     return failure();
567   SmallVector<Value> newResults;
568   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
569     return failure();
570   if (!newResults.empty())
571     rewriter.replaceOp(op, newResults);
572   else
573     rewriter.eraseOp(op);
574   return success();
575 }
576 
/// Apply each pattern set in `stage1Patterns` to convergence on `op`, running
/// the `stage2Patterns` cleanup set after each one and, when provided, the
/// `stage3Lambda` callback. Fails as soon as any stage fails.
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration; // Only referenced inside LLVM_DEBUG; silences -Wunused.
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
607 
608 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
609   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
610 }
611 
612 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
613 /// with pad_val) and GenericOp (to copy contents).
614 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
615     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
616 
617   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
618   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
619 
620   // Bail on non-static shapes.
621   if (!inputShapedType.hasStaticShape())
622     return failure();
623   if (!resultShapedType.hasStaticShape())
624     return failure();
625 
626   // Only support padding with a constant for now, i.e. either:
627   //   1. A BBarg from a different block.
628   //   2. A value defined outside of the current block.
629   Block &block = padOp.region().front();
630   auto yieldOp = cast<YieldOp>(block.getTerminator());
631   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
632   Value padValue = yieldOp.values().front();
633   Operation *definingOp = padValue.getDefiningOp();
634   if (definingOp && definingOp->getBlock() == &block)
635     return failure();
636   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
637     return failure();
638 
639   // Create tensor with the padded shape
640   Location loc = padOp.getLoc();
641   SmallVector<Value> indices(resultShapedType.getRank(),
642                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
643   Value initTensor = rewriter.create<InitTensorOp>(
644       loc, resultShapedType.getShape(), resultShapedType.getElementType());
645 
646   // Initialize tensor with the pad value
647   Value tmpTensor =
648       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
649 
650   // Copy original contents into new tensor
651   // Uses linalg.generic, but could be done with tensor.insert_slice
652   SmallVector<AffineExpr, 4> outputExprs;
653   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
654     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
655                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
656   }
657 
658   SmallVector<AffineMap, 2> transferMaps = {
659       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
660       AffineMap::get(resultShapedType.getRank(),
661                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
662 
663   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
664       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
665       getNParallelLoopsAttrs(resultShapedType.getRank()),
666       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
667         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
668       });
669 
670   return success();
671 }
672 
/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp whose region is cloned from the
/// PadTensorOp's region.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  // The guard restores the insertion point after we repoint it at the cloned
  // terminator below.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
698 
/// Lower a PadTensorOp to InitTensorOp + fill (or tensor.generate, see
/// createFillOrGenerateOp) + a copy of the source. The copy is delegated to
/// `optimizeCopyFn` when set and successful, else emitted as an insert_slice.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      // Dynamic result size = source size + low pad + high pad for this dim.
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source is inserted at its low-pad offsets inside the filled tensor.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
761 
762 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
763     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
764   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
765   if (!padOp)
766     return failure();
767   // Only unit stride supported.
768   if (!sliceOp.hasUnitStride())
769     return failure();
770 
771   Operation *tiledPadOp = padOp.getTiledImplementation(
772       rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
773       sliceOp.getMixedSizes());
774   // All shapes are static and the data source is actually used. Rewrite into
775   // pad_tensor(subtensor(x)).
776   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
777   return success();
778 }
779