//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
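
// For example, a pattern guarded by such a filter may only match operations
// that carry the marker (illustrative IR; the attribute value "tile" is an
// arbitrary filter name):
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
//          ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>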

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter and matchDisjunction is empty or matchByDefault is
    //    set: accept the op.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
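
// Example usage (illustrative): request fixed tile sizes for the first two
// loops of a linalg op.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});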

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
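
// For example (illustrative), a linalg op whose loop ranges derive from a
// tensor<4x?xf32> operand receives the tile sizes [0, 1]: the static loop is
// left untiled (size 0) and the dynamic loop is scalarized (size 1).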

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    paddedShape[shapeIdx++] = upperBound.getValue();
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}
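
// For instance (illustrative), if the operand traces back to
//
//   %slice = tensor.extract_slice %t[0, 0] [%sz, 16] [1, 1]
//       : tensor<64x16xf32> to tensor<?x16xf32>
//
// and %sz has the constant upper bound 8, the operand is padded to
// tensor<8x16xf32> with a tensor.pad using the requested padding value.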

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
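
// For example (illustrative), peeling
//
//   scf.for %i = %c0 to %c10 step %c4 { ... }
//
// yields a main loop from 0 to 8 with step 4 that runs only full iterations,
// followed by a separate partial iteration covering [8, 10).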

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains non-zero tile sizes.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // New filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}
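
// For example (illustrative), an interchange vector of {1, 0} swaps the two
// loops of a 2-D linalg.generic, i.e. its indexing maps are rewritten as if
// the iteration order were (j, i) instead of (i, j).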

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}
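
// For example (illustrative), generalizing a linalg.matmul produces a
// linalg.generic with the same matmul indexing maps and an explicit
// multiply-accumulate region; the computation itself is unchanged.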

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops, so
  // if the promotion fails, those need to be cleaned up, which doesn't seem to
  // be happening here. To fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
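
// For example (illustrative), padding a tensor<2x3xf32> by one element on
// each side of both dimensions becomes, schematically:
//
//   %init = linalg.init_tensor [4, 5] : tensor<4x5xf32>
//   %fill = linalg.fill ins(%pad_val : f32) outs(%init : tensor<4x5xf32>)
//   %res  = linalg.generic ... ins(%source) outs(%fill)  // copy contents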

/// Fill `dest` with the constant padding value of `padOp` using a FillOp if
/// possible. Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The tensor::PadOp copy could not be optimized. Generate an InsertSliceOp
  // instead for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
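
// When the copy cannot be optimized, the overall rewrite is, schematically
// (illustrative):
//
//   %init = linalg.init_tensor [...] : tensor<...>
//   %fill = linalg.fill ins(%pad_val : f32) outs(%init : tensor<...>)
//   %res  = tensor.insert_slice %source into %fill[<low pad>][<src sizes>][1, ...]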

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
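
// Schematically (illustrative), the swap rewrites
//
//   %0 = tensor.pad %src low[...] high[...] { ... }
//   %1 = tensor.extract_slice %0 [...]
//
// into a tensor.extract_slice of %src followed by a tensor.pad of the smaller
// slice; `zeroSliceGuard` guards the case where the slice reads only padding.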

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
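//
// For example (illustrative), a linalg.conv_2d_nhwc_hwcf with a 1xKW kernel
// and size-1 output height collapses to a linalg.conv_1d_nwc_wcf on
// rank-reduced operands.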

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}