1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SCF/Transforms.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Utils/StaticValueUtils.h"
23 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "llvm/ADT/ScopeExit.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <type_traits>
35 #include <utility>
36 
// Debug category name for this file; DBGS() below also uses it as the prefix
// of every debug message.
#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

// Yields llvm::dbgs() after writing the "[linalg-transforms]: " prefix, so
// callers can keep streaming the message with `<<`.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
// Name of the string attribute used as a marker on ops by the Linalg rewriting
// transformations in this file. LinalgTransformationFilter reads it to decide
// whether an op matches and sets/removes it after a transformation.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
50 
/// Builds a filter that matches ops whose marker attribute equals any entry of
/// `matchDisjunction`. `replacement`, when set, is the marker installed by
/// replaceLinalgTransformationFilter; when unset the marker is removed
/// instead. `matchByDefault` starts out false.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}
55 
56 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
57     const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
58     Optional<StringAttr> replacement)
59     : filters(),
60       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
61       replacement(replacement), matchByDefault(false) {
62   if (f)
63     filters.push_back(f);
64 }
65 
66 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
67     PatternRewriter &rewriter, Operation *op) const {
68   if (llvm::any_of(filters,
69                    [&](const FilterFunction &f) { return failed(f(op)); }))
70     return failure();
71 
72   auto attr = op->template getAttrOfType<StringAttr>(
73       LinalgTransforms::kLinalgTransformMarker);
74 
75   if (!attr) {
76     // 1. Has no filter case and matchDisjunction is empty.
77     if (matchDisjunction.empty() || matchByDefault)
78       return success();
79 
80     // 2. Has no filter but was expecting a filter.
81     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
82       diag << " does not have any filter from list: ";
83       interleaveComma(matchDisjunction, diag);
84     });
85   }
86 
87   // 4. Match explicit filter.
88   for (auto filter : matchDisjunction)
89     if (attr.getValue() == filter)
90       return success();
91 
92   // 5. Fail to match.
93   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
94     diag << " does not have any filter from list: ";
95     interleaveComma(matchDisjunction, diag);
96   });
97 }
98 
99 void mlir::linalg::LinalgTransformationFilter::
100     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
101                                       Operation *op) const {
102   if (replacement.hasValue())
103     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
104                 replacement.getValue());
105   else
106     op->removeAttr(
107         rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
108 }
109 
110 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
111     Operation *op) const {
112   if (!replacement)
113     return false;
114   auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
115                   .dyn_cast<StringAttr>();
116   return attr && attr == replacement.getValue();
117 }
118 
119 LinalgTilingOptions &
120 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
121   assert(!tileSizeComputationFunction && "tile sizes already set");
122   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
123   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
124     OpBuilder::InsertionGuard guard(b);
125     b.setInsertionPointToStart(
126         &op->getParentOfType<FuncOp>().getBody().front());
127     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
128       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
129       return v;
130     }));
131   };
132   return *this;
133 }
134 
135 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
136   assert(!tileSizeComputationFunction && "tile sizes already set");
137   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
138     SmallVector<Value, 4> tileSizes;
139     auto linalgOp = dyn_cast<LinalgOp>(op);
140     if (!linalgOp)
141       return tileSizes;
142     Location loc = linalgOp.getLoc();
143     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
144     AffineMap map = linalgOp.getShapesToLoopsMap();
145     if (!map)
146       return tileSizes;
147     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
148     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
149     // size 0).
150     for (Value shapeSize : shapeSizes)
151       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
152                               ? b.create<arith::ConstantIndexOp>(loc, 0)
153                               : b.create<arith::ConstantIndexOp>(loc, 1));
154     return tileSizes;
155   };
156   return *this;
157 }
158 
159 /// Helper function that tries to pad `opOperand`. Exit early for scalar
160 /// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
161 /// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
162 /// has a static shape. Set `result` to the result of the created tensor::PadOp
163 /// or and return success if the operand either has been padded to a static
164 /// shape or already had a static shape and failure otherwise.
165 static LogicalResult padOperandToSmallestStaticBoundingBox(
166     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
167     const PaddingValueComputationFunction &paddingFunc,
168     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
169   // Get the shape of the operand and check if it has a dynamic shape. Only
170   // return failure if the operand is not a scalar and has a dynamic shape.
171   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
172   bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
173 
174   // Cannot pad scalar operands.
175   if (shape.empty())
176     return success();
177 
178   // Cannot pad if the padding value is unknown.
179   FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
180   if (failed(paddingValue))
181     return failure(hasDynamicShape);
182 
183   // Cannot construct a static bounding box if the operand is not defined by an
184   // ExtractSliceOp.
185   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
186   if (!sliceOp)
187     return failure(hasDynamicShape);
188 
189   // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
190   llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
191 
192   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
193   SmallVector<int64_t> staticSizes;
194   staticSizes.reserve(shape.size());
195   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
196   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
197     // Skip dropped dimensions.
198     if (droppedDims.contains(en.index()))
199       continue;
200     // If the size is an attribute add it directly to `staticSizes`.
201     if (en.value().is<Attribute>()) {
202       staticSizes.push_back(
203           en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
204       continue;
205     }
206     // Otherwise, try to compute a constant upper bound for the size value.
207     FailureOr<int64_t> upperBound =
208         getConstantUpperBoundForIndex(en.value().get<Value>());
209     if (failed(upperBound)) {
210       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
211       return failure();
212     }
213     staticSizes.push_back(upperBound.getValue());
214   }
215   assert(staticSizes.size() == shape.size() &&
216          "expect the dynamic and static ranks to match");
217 
218   // Pad the operand to the bounding box defined by `staticSizes`.
219   auto staticTensorType = RankedTensorType::get(
220       staticSizes, getElementTypeOrSelf(opOperand->get()));
221   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
222   result =
223       makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
224                             opOperand->get(), paddingValue.getValue(), nofold);
225   return success();
226 }
227 
228 FailureOr<SmallVector<Value>>
229 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
230                           const PaddingValueComputationFunction &paddingFunc,
231                           const PaddingNoFoldComputationFunction &nofoldFunc,
232                           LinalgOp &paddedOp) {
233   Location loc = opToPad->getLoc();
234 
235   // TODO: there are cases where we may still want to pad to larger sizes.
236   assert(opToPad.hasTensorSemantics() &&
237          "expected operation to have tensor semantics");
238 
239   OpBuilder::InsertionGuard g(b);
240   // Set IP after op because we also take the dims of the original output.
241   b.setInsertionPointAfter(opToPad);
242   // Make a copy of the shaped operands and update it.
243   SmallVector<Value> newOperands;
244   newOperands.reserve(opToPad.getNumInputsAndOutputs());
245   for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
246     Value paddedOperand;
247     // If padding was requested but the shape cannot be bounded statically then
248     // the pattern fails to apply.
249     if (failed(padOperandToSmallestStaticBoundingBox(
250             b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
251       return failure();
252     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
253   }
254 
255   SmallVector<SmallVector<Value>> reifiedResultShapes;
256   if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
257                  .reifyResultShapes(b, reifiedResultShapes)))
258     return failure();
259   assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
260          "expected same number of results");
261 
262   // Clone `opToPad` to operate on the statically padded shapes.
263   auto resultTensorTypes =
264       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
265   paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
266 
267   // Recover the slice out of the new static results. This keeps the original
268   // linalg op around because it uses the dims of the original results.
269   SmallVector<Value> paddedSubviewResults;
270   paddedSubviewResults.reserve(opToPad->getNumResults());
271   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
272     Value paddedResult = en.value();
273     int64_t resultNumber = en.index();
274     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
275     SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
276     SmallVector<OpFoldResult> sizes;
277     for (Value v : reifiedResultShapes[resultNumber])
278       sizes.push_back(getAsOpFoldResult(v));
279     SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
280     paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
281         loc, paddedResult, offsets, sizes, strides));
282   }
283   return paddedSubviewResults;
284 }
285 
286 /// Try to peel a loop `op` and return the new result.
287 // TODO: Add support for scf.parallel and affine.for loops.
288 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
289   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
290       .Case<scf::ForOp>([&](scf::ForOp forOp) {
291         scf::ForOp partialIteration;
292         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
293                                                       partialIteration)))
294           return partialIteration->getResults();
295         assert(!partialIteration && "expected that loop was not peeled");
296         return forOp->getResults();
297       })
298       .Default([&](Operation *op) { return op->getResults(); });
299 }
300 
301 /// Try to peel a TiledLoopOp and return the new result.
302 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
303                                       TiledLoopOp tiledLoop, int64_t idx) {
304   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
305          "requested peeling of non-existing loop");
306   TiledLoopOp result;
307   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
308     return result->getResults();
309   assert(!result && "expected that loop was not peeled");
310   return tiledLoop->getResults();
311 }
312 
313 /// Peel loops after tiling.
314 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
315                                      ArrayRef<int64_t> peeledLoops,
316                                      LinalgTilingLoopType loopType) {
317   for (int64_t loop : peeledLoops) {
318     assert(loop < static_cast<int64_t>(res.loops.size()) &&
319            "requested peeling of non-existing loop");
320     SmallVector<Value, 4> loopResults;
321     Operation *loopOp = res.loops[loop];
322     if (loopType == LinalgTilingLoopType::TiledLoops) {
323       assert(llvm::all_of(
324                  res.loops,
325                  [&](Operation *op) { return op == res.loops.front(); }) &&
326              "expected that all loop ops are the same TiledLoopOp");
327       auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
328       assert(tiledLoopOp && "expected TiledLoopOp");
329       loopResults = peelLoop(rewriter, tiledLoopOp, loop);
330     } else {
331       loopResults = peelLoop(rewriter, loopOp);
332     }
333 
334     // The result of the loop nest may change with peeling.
335     if (res.tensorResults.size() == loopOp->getNumResults() &&
336         std::equal(res.tensorResults.begin(), res.tensorResults.end(),
337                    loopOp->getResults().begin()))
338       res.tensorResults = loopResults;
339   }
340 }
341 
342 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
343   if (tiledOp.loops.empty())
344     return tiledOp.op.getOperation()->getResults();
345   return tiledOp.loops.front()->getResults();
346 }
347 
348 static ValueRange
349 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
350   if (tiledAndFusedOp.fusedLoops.empty())
351     return tiledAndFusedOp.op.getOperation()->getResults();
352   return tiledAndFusedOp.fusedLoops.front()->getResults();
353 }
354 
/// Builds a tile-and-fuse pattern rooted at ops named `opName`.
/// `dependenceGraph` is stored as given; the options and the three filters are
/// moved into the pattern. matchAndRewrite uses `f` to select the root op and
/// mark the tiled op, `fusedOpMarker` to mark the fused producers, and
/// `originalOpMarker` to mark the original producers and the original root.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}
366 
/// Tiles the root LinalgOp `op` and fuses the producers selected by
/// `fusionOptions.indicesToFuse`, then tiles the loops that were not fused.
///
/// Fails without rewriting if `op` is not a LinalgOp, has index semantics, or
/// is rejected by `filter`; also fails if tile-and-fuse or the tiling of the
/// unfused loops fails. On success all uses of `op` are replaced by the
/// results of the tiled and fused op, and the markers of the tiled op, the
/// fused producers, the original producers and the original root are updated
/// through the respective filters.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the root op and the LinalgOp producers feeding one of the operand
  // indices requested for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Order the selected producers as they appear in the block before `op`; the
  // root op goes last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize the tile sizes once and reuse them for both tiling steps.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops: loops that were already fused get tile size zero,
  // the others keep their requested tile size.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size. Tile sizes beyond the
  // number of loops are dropped; a tile size that is not defined by a constant
  // index op is treated as non-zero.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    // The op tiled in the second step replaces the op from the fusion step.
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update the markers of the tiled op, the fused producers, the original
  // producers (all fusion ops except the trailing root) and the original root.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
452 
/// Linalg tiling pattern.
/// Builds a tiling pattern applicable to any LinalgOp accepted by `f`, tiled
/// according to `options`.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}
459 
/// Builds a tiling pattern whose filter is `f` extended through
/// `addOpNameFilter(opName)` (presumably restricting matches to ops named
/// `opName` - confirm against LinalgTransformationFilter).
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
465 
466 FailureOr<TiledLinalgOp>
467 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
468     LinalgOp op, PatternRewriter &rewriter) const {
469   if (failed(filter.checkAndNotify(rewriter, op)))
470     return failure();
471 
472   FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
473   if (failed(res))
474     return failure();
475 
476   // Clear filter to stop recursive pattern application.
477   // This must be done here to properly propagate to peeling branches.
478   filter.replaceLinalgTransformationFilter(rewriter, res->op);
479 
480   // Peel the loops of the TiledLinalgOp.
481   peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
482 
483   if (res->tensorResults.empty())
484     rewriter.eraseOp(op);
485   else
486     rewriter.replaceOp(op, res->tensorResults);
487 
488   return res;
489 }
490 
/// Linalg padding pattern.
/// Builds a padding pattern applicable to any LinalgOp accepted by `f`, padded
/// according to `options`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}
497 
/// Builds a padding pattern whose filter is `f` extended through
/// `addOpNameFilter(opName)` (presumably restricting matches to ops named
/// `opName` - confirm against LinalgTransformationFilter).
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
503 
504 FailureOr<LinalgOp>
505 mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
506     LinalgOp linalgOp, PatternRewriter &rewriter) const {
507   if (!linalgOp.hasTensorSemantics())
508     return failure();
509   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
510     return failure();
511 
512   // Pad the operation.
513   LinalgOp paddedOp;
514   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
515       rewriter, linalgOp, options.paddingValueComputationFunction,
516       options.paddingNoFoldComputationFunction, paddedOp);
517   if (failed(newResults))
518     return failure();
519 
520   // Compute the desired hoisting depths.
521   SmallVector<int64_t> depths;
522   if (options.paddingHoistComputationFunction) {
523     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
524       depths.push_back(options.paddingHoistComputationFunction(*opOperand));
525   }
526 
527   // Hoist the padding.
528   for (const auto &en : enumerate(depths)) {
529     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
530     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
531     if (!padOp || en.value() == 0)
532       continue;
533     tensor::PadOp hoistedOp;
534     SmallVector<GenericOp> transposeOps;
535     SmallVector<int64_t> transposeVector =
536         options.paddingTransposeComputationFunction(opOperand);
537 
538     FailureOr<Value> newResult = hoistPaddingOnTensors(
539         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
540     if (failed(newResult))
541       continue;
542     rewriter.replaceOp(padOp, newResult.getValue());
543 
544     // Do not apply hoist padding to the newly introduced transpose operations.
545     for (GenericOp transposeOp : transposeOps)
546       filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
547   }
548 
549   // Replace the original operation to pad.
550   rewriter.replaceOp(linalgOp, newResults.getValue());
551   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
552 
553   return paddedOp;
554 }
555 
/// Linalg tile and fuse tensor ops pattern.
/// Builds a pattern that is registered for any op type (MatchAnyOpTypeTag);
/// matchAndRewrite itself rejects ops that are not LinalgOps or that `f` does
/// not accept.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}
564 
/// Builds a pattern rooted at ops named `opName`. Here the name restriction is
/// passed to the RewritePattern base class; `f` is stored unchanged.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}
572 
573 LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
574     Operation *op, PatternRewriter &rewriter) const {
575   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
576   if (!rootOp)
577     return failure();
578   if (failed(filter.checkAndNotify(rewriter, op)))
579     return failure();
580 
581   // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
582   if (options.tileSizes.size() < rootOp.getNumLoops())
583     return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
584 
585   // Check `tileInterchange` contains no entries or as many as `tileSizes`.
586   if (!options.tileInterchange.empty() &&
587       options.tileInterchange.size() != options.tileSizes.size())
588     return rewriter.notifyMatchFailure(
589         op, "expect the number of tile sizes and interchange dims to match");
590 
591   // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
592   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
593                                      options.tileSizes.begin() +
594                                          rootOp.getNumLoops());
595   if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; })) {
596     return rewriter.notifyMatchFailure(
597         op, "all tile sizes are zero, nothing to do");
598   }
599   SmallVector<int64_t> rootInterchange =
600       options.tileInterchange.empty()
601           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
602           : SmallVector<int64_t>(options.tileInterchange.begin(),
603                                  options.tileInterchange.begin() +
604                                      rootOp.getNumLoops());
605 
606   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
607   // It has to be a permutation since the tiling cannot tile the same loop
608   // dimension multiple times.
609   if (!isPermutation(rootInterchange))
610     return rewriter.notifyMatchFailure(
611         op, "expect the tile interchange permutes the root loops");
612 
613   // Tile `rootOp` and fuse its producers.
614   FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
615       rewriter, rootOp, rootTileSizes, rootInterchange);
616   if (failed(tileLoopNest))
617     return rewriter.notifyMatchFailure(
618         op, "tileConsumerAndFuseProducers failed unexpectedly");
619 
620   // Replace all uses of the tiled loop operation.
621   rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
622 
623   // Apply the filter if specified.
624   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
625     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
626   return failure();
627 }
628 
/// Linalg generic interchange pattern.
/// Builds a pattern that interchanges the loops of GenericOps accepted by `f`
/// according to `interchangeVector`, which is copied into the pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
635 
636 FailureOr<GenericOp>
637 mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
638     GenericOp genericOp, PatternRewriter &rewriter) const {
639   if (failed(filter.checkAndNotify(rewriter, genericOp)))
640     return failure();
641 
642   FailureOr<GenericOp> transformedOp =
643       interchangeGenericOp(rewriter, genericOp, interchangeVector);
644   if (failed(transformedOp))
645     return failure();
646 
647   // New filter if specified.
648   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
649   return transformedOp;
650 }
651 
/// Linalg generalization pattern.
/// Builds a pattern that generalizes any LinalgOp accepted by `f`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}
657 
/// Builds a generalization pattern whose filter is `f` extended through
/// `addOpNameFilter(opName)` (presumably restricting matches to ops named
/// `opName` - confirm against LinalgTransformationFilter).
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}
663 
664 FailureOr<GenericOp>
665 mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
666     LinalgOp linalgOp, PatternRewriter &rewriter) const {
667   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
668     return failure();
669   FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
670   if (failed(genericOp))
671     return failure();
672   filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
673   return genericOp;
674 }
675 
/// Builds a promotion pattern that is registered for any op type
/// (MatchAnyOpTypeTag); `f` and `options` are moved into the pattern. Note
/// that this overload takes the filter before the options.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}
681 
/// Builds a promotion pattern rooted at ops named `opName`; `f` and `options`
/// are moved into the pattern. Note that this overload takes the options
/// before the filter.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}
687 
688 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
689     Operation *op, PatternRewriter &rewriter) const {
690   if (failed(filter.checkAndNotify(rewriter, op)))
691     return failure();
692   if (failed(promoteSubviewsPrecondition(op, options)))
693     return failure();
694 
695   // TODO: We cannot use root update here. This pattern is creating other ops,
696   // so if the promotion fails, those need to be cleaned up, which doesnt seem
697   // to be happening here. So to fail properly, we should be cloning the op and
698   // deleting the previous op. This needs more investigation.
699   rewriter.startRootUpdate(op);
700   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
701   if (!promotedOp) {
702     rewriter.cancelRootUpdate(op);
703     return op->emitError("subview promotion failed");
704   }
705   rewriter.finalizeRootUpdate(op);
706   filter.replaceLinalgTransformationFilter(rewriter, op);
707   return success();
708 }
709 
/// Constructs a vectorization pattern on top of
/// `OpInterfaceRewritePattern<LinalgOp>`, moving `f` into the `filter` member.
///
/// NOTE: `options` is accepted but not used by this constructor; nothing is
/// stored from it.
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}
715 
/// Constructs a vectorization pattern whose filter is additionally configured
/// with `opName` through `addOpNameFilter`; the result of that call
/// initializes the `filter` member.
///
/// NOTE: `options` is accepted but not used by this constructor. Also note
/// that `options` precedes `f` here, unlike in the name-less constructor.
/// NOTE(review): presumably the op-name filter restricts matching to ops named
/// `opName` -- confirm against LinalgTransformationFilter.
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}
721 
722 LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
723     LinalgOp linalgOp, PatternRewriter &rewriter) const {
724   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
725     return failure();
726   return vectorize(rewriter, linalgOp);
727 }
728 
729 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
730     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
731   return vectorizeCopy(rewriter, copyOp);
732 }
733 
734 LogicalResult mlir::linalg::applyStagedPatterns(
735     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
736     const FrozenRewritePatternSet &stage2Patterns,
737     function_ref<LogicalResult(Operation *)> stage3Lambda) {
738   unsigned iteration = 0;
739   (void)iteration;
740   for (const auto &patterns : stage1Patterns) {
741     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
742                       << *op);
743     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
744       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
745       return failure();
746     }
747     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
748                       << *op);
749     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
750       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
751       return failure();
752     }
753     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
754                       << *op);
755     if (stage3Lambda) {
756       if (failed(stage3Lambda(op)))
757         return failure();
758       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
759                         << *op);
760     }
761   }
762   return success();
763 }
764 
765 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
766   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
767 }
768 
769 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
770 /// initialize with pad_val) and GenericOp (to copy contents).
771 LogicalResult
772 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
773                                             PatternRewriter &rewriter) const {
774 
775   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
776   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
777 
778   // Bail on non-static shapes.
779   if (!inputShapedType.hasStaticShape())
780     return failure();
781   if (!resultShapedType.hasStaticShape())
782     return failure();
783 
784   // Only support padding with a constant for now, i.e. either:
785   //   1. A BBarg from a different block.
786   //   2. A value defined outside of the current block.
787   Block &block = padOp.region().front();
788   auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
789   Value padValue = yieldOp.value();
790   Operation *definingOp = padValue.getDefiningOp();
791   if (definingOp && definingOp->getBlock() == &block)
792     return failure();
793   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
794     return failure();
795 
796   // Create tensor with the padded shape
797   Location loc = padOp.getLoc();
798   SmallVector<Value> indices(resultShapedType.getRank(),
799                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
800   Value initTensor = rewriter.create<InitTensorOp>(
801       loc, resultShapedType.getShape(), resultShapedType.getElementType());
802 
803   // Initialize tensor with the pad value
804   Value tmpTensor =
805       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
806 
807   // Copy original contents into new tensor
808   // Uses linalg.generic, but could be done with tensor.insert_slice
809   SmallVector<AffineExpr, 4> outputExprs;
810   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
811     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
812                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
813   }
814 
815   SmallVector<AffineMap, 2> transferMaps = {
816       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
817       AffineMap::get(resultShapedType.getRank(),
818                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
819 
820   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
821       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
822       getNParallelLoopsAttrs(resultShapedType.getRank()),
823       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
824         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
825       });
826 
827   return success();
828 }
829 
830 /// Filling `dest` using FillOp constant padding value if possible.
831 /// Otherwise, generate a tensor::GenerateOp.
832 Value GeneralizePadOpPattern::createFillOrGenerateOp(
833     PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
834     const SmallVector<Value> &dynSizes) const {
835   auto padValue = padOp.getConstantPaddingValue();
836   if (padValue)
837     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
838 
839   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
840   auto generateOp = rewriter.create<tensor::GenerateOp>(
841       padOp.getLoc(), padOp.getResultType(), dynSizes);
842   // Copy region to new op.
843   BlockAndValueMapping bvm;
844   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
845   return generateOp;
846 }
847 
848 LogicalResult
849 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
850                                         PatternRewriter &rewriter) const {
851   // Given an OpFoldResult, return an index-typed value.
852   auto getIdxValue = [&](OpFoldResult ofr) {
853     if (auto val = ofr.dyn_cast<Value>())
854       return val;
855     return rewriter
856         .create<arith::ConstantIndexOp>(
857             padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
858         .getResult();
859   };
860 
861   auto resultType = padOp.getResultType();
862   // Compute size of InitTensorOp. Any combination of static/dynamic is
863   // supported.
864   SmallVector<Value> dynSizes;
865   SmallVector<int64_t> staticSizes;
866   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
867     if (resultType.isDynamicDim(dim)) {
868       auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
869                                                           padOp.source(), dim);
870       // Add low and high padding value.
871       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
872           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
873       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
874           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
875       dynSizes.push_back(plusHigh);
876     }
877     staticSizes.push_back(resultType.getDimSize(dim));
878   }
879 
880   // Init tensor and fill it with padding.
881   Value init = rewriter.create<InitTensorOp>(
882       padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
883   Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
884 
885   // Try optimize the copy of source.
886   if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
887     return success();
888 
889   // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
890   // for copying the PadOp source.
891   auto sourceType = padOp.getSourceType();
892   // Compute size of source of tensor::PadOp.
893   SmallVector<OpFoldResult> srcSizes;
894   for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
895     if (sourceType.isDynamicDim(dim)) {
896       srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
897           padOp.getLoc(), padOp.source(), dim));
898     } else {
899       srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
900     }
901   }
902   // Strides of InsertSliceOp are all 1.
903   SmallVector<OpFoldResult> strides(sourceType.getRank(),
904                                     rewriter.getIndexAttr(1));
905   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
906       padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
907 
908   return success();
909 }
910 
911 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
912     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
913   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
914   if (!padOp)
915     return failure();
916   // Only unit stride supported.
917   if (!sliceOp.hasUnitStride())
918     return failure();
919 
920   TilingInterface tilingInterface =
921       dyn_cast<TilingInterface>(padOp.getOperation());
922   Operation *tiledPadOp =
923       tilingInterface
924           .getTiledImplementation(
925               rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
926               sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
927           .front();
928   // All shapes are static and the data source is actually used. Rewrite into
929   // pad_tensor(subtensor(x)).
930   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
931   return success();
932 }
933 
934 namespace {
935 // The following are patterns for downscaling convolution ops with size-1
936 // window dimensions.
937 //
938 // Note that we'd eventually want to write such transformations in a generic
939 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
940 // and then turning back to named ops. But for now it's fine to have a few
941 // patterns matching special ops to get started.
942 
943 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
944 /// convolution ops.
945 struct DownscaleSizeOneWindowed2DConvolution final
946     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
947   DownscaleSizeOneWindowed2DConvolution(
948       MLIRContext *context,
949       LinalgTransformationFilter f = LinalgTransformationFilter(),
950       PatternBenefit benefit = 1)
951       : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
952         filter(std::move(f)) {}
953 
954   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
955                                 PatternRewriter &rewriter) const override {
956     if (failed(filter.checkAndNotify(rewriter, convOp)))
957       return failure();
958     if (convOp.hasBufferSemantics())
959       return failure(); // To be implemented
960 
961     Value input = convOp.inputs().front();
962     Value kernel = convOp.inputs().back();
963     Value output = convOp.outputs().front();
964 
965     auto inputType = input.getType().dyn_cast<RankedTensorType>();
966     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
967     auto outputType = output.getType().dyn_cast<RankedTensorType>();
968 
969     auto kernelShape = kernelType.getShape();
970     auto outputShape = outputType.getShape();
971 
972     // Only handle the case where at least one of the window dimensions is
973     // of size 1. Other cases can rely on tiling to reduce to such cases.
974     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
975     int64_t ohSize = outputShape[1], owSize = outputShape[2];
976     bool removeH = (khSize == 1 && ohSize == 1);
977     bool removeW = (kwSize == 1 && owSize == 1);
978     if (!removeH && !removeW)
979       return failure();
980 
981     // Get new shapes and types for all operands by removing the size-1
982     // dimension.
983     using RTTBuilder = RankedTensorType::Builder;
984     RankedTensorType newInputType =
985         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
986     RankedTensorType newKernelType =
987         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
988     RankedTensorType newOutputType =
989         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
990 
991     // Rank-reduce operands.
992     Location loc = convOp.getLoc();
993     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
994         rewriter, loc, input, newInputType);
995     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
996         rewriter, loc, kernel, newKernelType);
997     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
998         rewriter, loc, output, newOutputType);
999 
1000     // Rank-reduce strides and dilations too.
1001     // TODO: dropDim 1-liner helper.
1002     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1003     strides.erase(strides.begin() + (removeH ? 0 : 1));
1004     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1005 
1006     auto dilations =
1007         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1008     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1009     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1010 
1011     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1012         loc, newOutputType, ValueRange{newInput, newKernel},
1013         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1014 
1015     // Insert back.
1016     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1017         rewriter, loc, conv1DOp.getResult(0), output);
1018     rewriter.replaceOp(convOp, inserted);
1019 
1020     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1021     return success();
1022   };
1023 
1024 private:
1025   /// LinalgTransformMarker handles special attribute manipulations.
1026   LinalgTransformationFilter filter;
1027 };
1028 
1029 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1030 /// dimensions into 1-D depthwise convolution ops.
1031 struct DownscaleDepthwiseConv2DNhwcHwcOp final
1032     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
1033   DownscaleDepthwiseConv2DNhwcHwcOp(
1034       MLIRContext *context,
1035       LinalgTransformationFilter f = LinalgTransformationFilter(),
1036       PatternBenefit benefit = 1)
1037       : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
1038         filter(std::move(f)) {}
1039 
1040   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
1041                                 PatternRewriter &rewriter) const override {
1042     if (failed(filter.checkAndNotify(rewriter, convOp)))
1043       return failure();
1044     if (convOp.hasBufferSemantics())
1045       return failure(); // To be implemented
1046 
1047     Value input = convOp.inputs().front();
1048     Value kernel = convOp.inputs().back();
1049     Value output = convOp.outputs().front();
1050 
1051     auto inputType = input.getType().dyn_cast<RankedTensorType>();
1052     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1053     auto outputType = output.getType().dyn_cast<RankedTensorType>();
1054 
1055     auto kernelShape = kernelType.getShape();
1056     auto outputShape = outputType.getShape();
1057 
1058     // Only handle the case where at least one of the window dimensions is
1059     // of size 1. Other cases can rely on tiling to reduce to such cases.
1060     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1061     int64_t ohSize = outputShape[1], owSize = outputShape[2];
1062     bool removeH = (khSize == 1 && ohSize == 1);
1063     bool removeW = (kwSize == 1 && owSize == 1);
1064     if (!removeH && !removeW)
1065       return failure();
1066 
1067     // Get new shapes and types for all operands by removing the size-1
1068     // dimension.
1069     using RTTBuilder = RankedTensorType::Builder;
1070     RankedTensorType newInputType =
1071         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1072     RankedTensorType newKernelType =
1073         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1074     RankedTensorType newOutputType =
1075         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1076 
1077     // Rank-reduce operands.
1078     Location loc = convOp.getLoc();
1079     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1080         rewriter, loc, input, newInputType);
1081     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1082         rewriter, loc, kernel, newKernelType);
1083     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1084         rewriter, loc, output, newOutputType);
1085 
1086     // Rank-reduce strides and dilations too.
1087     // TODO: dropDim 1-liner helper.
1088     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1089     strides.erase(strides.begin() + (removeH ? 0 : 1));
1090     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1091 
1092     auto dilations =
1093         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1094     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1095     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1096 
1097     auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1098         loc, newOutputType, ValueRange{newInput, newKernel},
1099         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1100 
1101     // Insert back.
1102     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1103         rewriter, loc, conv1DOp.getResult(0), output);
1104     rewriter.replaceOp(convOp, inserted);
1105 
1106     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1107     return success();
1108   };
1109 
1110 private:
1111   /// LinalgTransformMarker handles special attribute manipulations.
1112   LinalgTransformationFilter filter;
1113 };
1114 
1115 } // namespace
1116 
1117 void linalg::populateDecomposeConvolutionPatterns(
1118     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
1119     PatternBenefit benefit) {
1120   patterns.add<DownscaleSizeOneWindowed2DConvolution,
1121                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1122                                                   benefit);
1123 }
1124