//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.has_value())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == *replacement;
}

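/// Set the tile sizes to constant values; each size is materialized as a
/// constant index op at the start of the enclosing function.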
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

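/// Tile every dynamic loop dimension by 1 and leave static dimensions untiled
/// (tile size 0).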
LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize)
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of the
  // `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    paddedShape[shapeIdx++] = *upperBound;
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new results.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize `loops`.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops) {
    SmallVector<Value, 4> loopResults;
    loopResults = peelLoop(rewriter, loopOp);
  }
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, *newResult);

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, *newResults);
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains non-zero tile sizes.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the new filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

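/// Linalg peeling pattern.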
mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPeelOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Set the replacement marker even if peeling doesn't happen for this op.
  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

  if (!options.loopsToPeelComputationFunction)
    return failure();

  SmallVector<scf::ForOp, 4> loopsToPeel;
  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
  peelLoops(rewriter, loopsToPeel);
  return success();
}

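/// Linalg vectorization pattern.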
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

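/// Memref copy vectorization pattern.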
LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

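/// Apply each pattern set in `stage1Patterns` greedily, running
/// `stage2Patterns` after every stage-1 set and invoking `stage3Lambda`, if
/// provided, as a final step of each iteration.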
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

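/// Return `nParallelLoops` copies of the parallel iterator type name.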
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.getRegion().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.getValue();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy original contents into new tensor
  // Uses linalg.generic, but could be done with tensor.insert_slice
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The tensor::PadOp cannot be optimized. Generate an InsertSliceOp instead
  // to copy the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

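/// Rewrite linalg.conv_2d_nhwc_hwcf into linalg.conv_1d_nwc_wcf when either
/// the filter height or width (together with the corresponding output
/// dimension) has size 1.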
FailureOr<Conv1DNwcWcfOp>
DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

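/// Rewrite linalg.depthwise_conv_2d_nhwc_hwc into
/// linalg.depthwise_conv_1d_nwc_wc when either the filter height or width
/// (together with the corresponding output dimension) has size 1.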
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

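/// Populate `patterns` with the convolution decomposition patterns defined
/// above.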
void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}