//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no filter attribute and the match disjunction is empty,
    // or we match by default.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no filter attribute but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
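
// A minimal usage sketch (tile sizes are illustrative): tile the first two
// loops by 8 and 16 and leave the third loop untiled (tile size 0). The
// constants are materialized at the start of the enclosing function so they
// dominate all tiled loops:
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0});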

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exits early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, tries to pad the operand even if
/// it already has a static shape. Sets `result` to the result of the created
/// tensor::PadOp and returns success if the operand either has been padded to
/// a static shape or already had a static shape; returns failure otherwise.
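///
/// A sketch of the effect (illustrative IR; the bound 8 is hypothetical):
///   %s = tensor.extract_slice %t[0, 0] [%d, 16] [1, 1]
///          : tensor<?x16xf32> to tensor<?x16xf32>
/// with a constant upper bound of 8 computed for %d becomes the source of a
/// tensor.pad that pads up to tensor<8x16xf32>, so the consuming linalg op
/// operates on statically shaped tensors.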
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // If the size is an attribute add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
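// For example (bounds are illustrative), peeling an scf.for over [0, 10) with
// step 4 yields a main loop over [0, 8) and a separate partial iteration for
// the remainder [8, 10); the results of the partial iteration replace the
// results of the original loop.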
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op view is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Truncate the tile sizes to the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the remaining loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (const auto &en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        options.paddingTransposeComputationFunction(opOperand);

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains non-zero tile sizes.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the replacement filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
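///
/// A sketch of the rewrite (illustrative IR; static shapes and the names are
/// assumed):
///   %0 = tensor.pad %src low[1, 1] high[1, 1] { tensor.yield %cst }
/// becomes roughly:
///   %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
///   %0 = linalg.generic ... ins(%src) outs(%fill)  // writes %src shifted by
///                                                  // the low padding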
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses a linalg.generic, but this could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
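///
/// A minimal sketch, assuming a constant padding value %cst:
///   %fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<?x?xf32>)
/// Otherwise, the pad region is cloned into a tensor::GenerateOp whose body
/// yields the (possibly index-dependent) padding value for each position.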
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy of the PadOp source could not be optimized. Generate an
  // InsertSliceOp instead to copy the source into the filled tensor.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

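/// Swap an extract_slice of a pad with a pad of an extract_slice. A sketch
/// (illustrative IR; offsets and sizes are hypothetical):
///   %p = tensor.pad %x low[1] high[1] { ... }
///   %s = tensor.extract_slice %p[0] [4] [1]
/// becomes a tensor.pad of a smaller tensor.extract_slice of %x, computed by
/// tensor::bubbleUpPadSlice. When `zeroSliceGuard` is set, the rewrite guards
/// against degenerate slices that read no data from the pad source.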
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
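///
/// A sketch of the effect (illustrative IR): a linalg.conv_2d_nhwc_hwcf whose
/// kernel and output both have size 1 in the H dimension gets that dimension
/// rank-reduced on all operands with tensor.extract_slice, is rewritten into
/// a linalg.conv_1d_nwc_wcf, and the result is re-inserted into the original
/// output with tensor.insert_slice.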
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}