1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SCF/Transforms.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Utils/StaticValueUtils.h"
23 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "llvm/ADT/ScopeExit.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <type_traits>
35 #include <utility>
36 
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// LinalgTransformationFilter reads this attribute in `checkAndNotify` and sets
// or removes it in `replaceLinalgTransformationFilter`.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
50 
/// Builds a filter from the list of accepted marker values
/// (`matchDisjunction`) and an optional `replacement` marker. No filter
/// functions are installed and `matchByDefault` starts out false; see
/// `checkAndNotify` and `replaceLinalgTransformationFilter` for how these
/// fields are consulted.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}
55 
56 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
57     const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
58     Optional<StringAttr> replacement)
59     : filters(),
60       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
61       replacement(replacement), matchByDefault(false) {
62   if (f)
63     filters.push_back(f);
64 }
65 
66 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
67     PatternRewriter &rewriter, Operation *op) const {
68   if (llvm::any_of(filters,
69                    [&](const FilterFunction &f) { return failed(f(op)); }))
70     return failure();
71 
72   auto attr = op->template getAttrOfType<StringAttr>(
73       LinalgTransforms::kLinalgTransformMarker);
74 
75   if (!attr) {
76     // 1. Has no filter case and matchDisjunction is empty.
77     if (matchDisjunction.empty() || matchByDefault)
78       return success();
79 
80     // 2. Has no filter but was expecting a filter.
81     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
82       diag << " does not have any filter from list: ";
83       interleaveComma(matchDisjunction, diag);
84     });
85   }
86 
87   // 4. Match explicit filter.
88   for (auto filter : matchDisjunction)
89     if (attr.getValue() == filter)
90       return success();
91 
92   // 5. Fail to match.
93   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
94     diag << " does not have any filter from list: ";
95     interleaveComma(matchDisjunction, diag);
96   });
97 }
98 
99 void mlir::linalg::LinalgTransformationFilter::
100     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
101                                       Operation *op) const {
102   if (replacement.hasValue())
103     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
104                 replacement.getValue());
105   else
106     op->removeAttr(
107         rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
108 }
109 
110 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
111     Operation *op) const {
112   if (!replacement)
113     return false;
114   auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
115                   .dyn_cast<StringAttr>();
116   return attr && attr == replacement.getValue();
117 }
118 
119 LinalgTilingOptions &
120 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
121   assert(!tileSizeComputationFunction && "tile sizes already set");
122   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
123   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
124     OpBuilder::InsertionGuard guard(b);
125     b.setInsertionPointToStart(
126         &op->getParentOfType<FuncOp>().getBody().front());
127     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
128       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
129       return v;
130     }));
131   };
132   return *this;
133 }
134 
135 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
136   assert(!tileSizeComputationFunction && "tile sizes already set");
137   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
138     SmallVector<Value, 4> tileSizes;
139     auto linalgOp = dyn_cast<LinalgOp>(op);
140     if (!linalgOp)
141       return tileSizes;
142     Location loc = linalgOp.getLoc();
143     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
144     AffineMap map = linalgOp.getShapesToLoopsMap();
145     if (!map)
146       return tileSizes;
147     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
148     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
149     // size 0).
150     for (Value shapeSize : shapeSizes)
151       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
152                               ? b.create<arith::ConstantIndexOp>(loc, 0)
153                               : b.create<arith::ConstantIndexOp>(loc, 1));
154     return tileSizes;
155   };
156   return *this;
157 }
158 
159 /// Helper function that tries to pad `opOperand`. Exit early for scalar
160 /// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
161 /// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
162 /// has a static shape. Set `result` to the result of the created tensor::PadOp
163 /// or and return success if the operand either has been padded to a static
164 /// shape or already had a static shape and failure otherwise.
165 static LogicalResult padOperandToSmallestStaticBoundingBox(
166     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
167     const PaddingValueComputationFunction &paddingFunc,
168     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
169   // Get the shape of the operand and check if it has a dynamic shape. Only
170   // return failure if the operand is not a scalar and has a dynamic shape.
171   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
172   bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
173 
174   // Cannot pad scalar operands.
175   if (shape.empty())
176     return success();
177 
178   // Cannot pad if the padding value is unknown.
179   FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
180   if (failed(paddingValue))
181     return failure(hasDynamicShape);
182 
183   // Cannot construct a static bounding box if the operand is not defined by an
184   // ExtractSliceOp.
185   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
186   if (!sliceOp)
187     return failure(hasDynamicShape);
188 
189   // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
190   llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
191 
192   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
193   SmallVector<int64_t> staticSizes;
194   staticSizes.reserve(shape.size());
195   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
196   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
197     // Skip dropped dimensions.
198     if (droppedDims.test(en.index()))
199       continue;
200     // If the size is an attribute add it directly to `staticSizes`.
201     if (en.value().is<Attribute>()) {
202       staticSizes.push_back(
203           en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
204       continue;
205     }
206     // Otherwise, try to compute a constant upper bound for the size value.
207     FailureOr<int64_t> upperBound =
208         getConstantUpperBoundForIndex(en.value().get<Value>());
209     if (failed(upperBound)) {
210       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
211       return failure();
212     }
213     staticSizes.push_back(upperBound.getValue());
214   }
215   assert(staticSizes.size() == shape.size() &&
216          "expect the dynamic and static ranks to match");
217 
218   // Pad the operand to the bounding box defined by `staticSizes`.
219   auto staticTensorType = RankedTensorType::get(
220       staticSizes, getElementTypeOrSelf(opOperand->get()));
221   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
222   result =
223       makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
224                             opOperand->get(), paddingValue.getValue(), nofold);
225   return success();
226 }
227 
/// Pads every input and output operand of `opToPad` to a static bounding box
/// (see padOperandToSmallestStaticBoundingBox) and clones `opToPad` onto the
/// padded operands, storing the clone in `paddedOp`. For each result of the
/// clone, a tensor::ExtractSliceOp is created that extracts the original
/// (reified) result sizes; those slices are returned.
///
/// `paddingFunc` provides the padding value per operand. `nofoldFunc` may be
/// null and otherwise decides the `nofold` flag passed to
/// makeComposedPadHighOp. Returns failure if an operand cannot be padded or
/// the result shapes of `opToPad` cannot be reified.
/// NOTE(review): pad ops created before a failure are not erased here.
FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    // A null `paddedOperand` means the helper chose not to pad; keep the
    // original operand in that case.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the result shapes of the original op; they define the sizes of the
  // slices extracted from the padded results below.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes. The result
  // types are the types of the (padded) output operands, i.e. the trailing
  // getNumOutputs() entries of `newOperands`.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    // Zero offsets, unit strides and the reified sizes select the
    // original-sized region of the padded result.
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
285 
286 /// Try to peel a loop `op` and return the new result.
287 // TODO: Add support for scf.parallel and affine.for loops.
288 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
289   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
290       .Case<scf::ForOp>([&](scf::ForOp forOp) {
291         scf::ForOp partialIteration;
292         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
293                                                       partialIteration)))
294           return partialIteration->getResults();
295         assert(!partialIteration && "expected that loop was not peeled");
296         return forOp->getResults();
297       })
298       .Default([&](Operation *op) { return op->getResults(); });
299 }
300 
301 /// Try to peel a TiledLoopOp and return the new result.
302 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
303                                       TiledLoopOp tiledLoop, int64_t idx) {
304   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
305          "requested peeling of non-existing loop");
306   TiledLoopOp result;
307   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
308     return result->getResults();
309   assert(!result && "expected that loop was not peeled");
310   return tiledLoop->getResults();
311 }
312 
313 /// Peel loops after tiling.
314 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
315                                      ArrayRef<int64_t> peeledLoops,
316                                      LinalgTilingLoopType loopType) {
317   for (int64_t loop : peeledLoops) {
318     assert(loop < static_cast<int64_t>(res.loops.size()) &&
319            "requested peeling of non-existing loop");
320     SmallVector<Value, 4> loopResults;
321     Operation *loopOp = res.loops[loop];
322     if (loopType == LinalgTilingLoopType::TiledLoops) {
323       assert(llvm::all_of(
324                  res.loops,
325                  [&](Operation *op) { return op == res.loops.front(); }) &&
326              "expected that all loop ops are the same TiledLoopOp");
327       auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
328       assert(tiledLoopOp && "expected TiledLoopOp");
329       loopResults = peelLoop(rewriter, tiledLoopOp, loop);
330     } else {
331       loopResults = peelLoop(rewriter, loopOp);
332     }
333 
334     // The result of the loop nest may change with peeling.
335     if (res.tensorResults.size() == loopOp->getNumResults() &&
336         std::equal(res.tensorResults.begin(), res.tensorResults.end(),
337                    loopOp->getResults().begin()))
338       res.tensorResults = loopResults;
339   }
340 }
341 
342 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
343   if (tiledOp.loops.empty())
344     return tiledOp.op.getOperation()->getResults();
345   return tiledOp.loops.front()->getResults();
346 }
347 
348 static ValueRange
349 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
350   if (tiledAndFusedOp.fusedLoops.empty())
351     return tiledAndFusedOp.op.getOperation()->getResults();
352   return tiledAndFusedOp.fusedLoops.front()->getResults();
353 }
354 
/// Creates a tile-and-fuse pattern rooted on ops named `opName`.
/// - `tilingOptions` provides the tile sizes and tiling configuration.
/// - `fusionOptions` selects the operand indices whose producers are fused.
/// - `dependenceGraph` is queried for those producers.
/// - `f` gates which ops are matched.
/// - `fusedOpMarker` is applied to the fused producers after a successful
///   rewrite, and `originalOpMarker` to the original ops.
/// NOTE(review): `dependenceGraph` is presumably stored by reference and must
/// outlive the pattern — confirm against the header.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}
366 
/// Tiles `op` together with selected LinalgOp producers.
///
/// `op` must be a LinalgOp without index semantics that is accepted by
/// `filter`. The producers considered are those that `dependenceGraph` reports
/// for operand indices in `fusionOptions.indicesToFuse` and that appear before
/// `op` in its block. `op` and those producers are passed to
/// `tileAndFuseLinalgOps`.
///
/// Loop dimensions that were not fused are then tiled separately with the
/// remaining tile sizes. Finally, the markers on the tiled op, the fused
/// producers, the original producers and `op` itself are updated.
/// NOTE(review):
/// - `op` is not erased. Its uses are redirected with
///   `op->replaceAllUsesWith`, which bypasses the rewriter.
/// - The `zero` constant below is created even when no unfused tiling
///   happens.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect `linalgOp` plus every LinalgOp producer feeding an operand index
  // listed in `fusionOptions.indicesToFuse`.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the selected producers in block order (only those preceding `op`),
  // followed by `op` itself as the last element.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Evaluate the tile sizes once for this op and pin them into a copy of the
  // options so tiling and fusion use identical values.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops: zero out the tile sizes of the already-fused loop
  // dimensions and keep the rest.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Drop tile sizes beyond the number of loops of `linalgOp`.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the loop only if there is a non-zero tile size (a non-constant size
  // counts as non-zero).
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update markers: `filter` on the tiled op, `fusedOpMarker` on the fused
  // producers, `originalOpMarker` on the original producers and on `op`.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
452 
/// Linalg tiling pattern.
/// Matches any LinalgOp accepted by `f` and tiles it according to `options`.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

/// Same as the overload without `opName`, but the stored filter additionally
/// requires the op name `opName` (added to `f` via addOpNameFilter).
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
465 
466 FailureOr<TiledLinalgOp>
467 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
468     LinalgOp op, PatternRewriter &rewriter) const {
469   if (failed(filter.checkAndNotify(rewriter, op)))
470     return failure();
471 
472   FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
473   if (failed(res))
474     return failure();
475 
476   // Clear filter to stop recursive pattern application.
477   // This must be done here to properly propagate to peeling branches.
478   filter.replaceLinalgTransformationFilter(rewriter, res->op);
479 
480   // Peel the loops of the TiledLinalgOp.
481   peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
482 
483   if (res->tensorResults.empty())
484     rewriter.eraseOp(op);
485   else
486     rewriter.replaceOp(op, res->tensorResults);
487 
488   return res;
489 }
490 
/// Linalg padding pattern.
/// Matches any LinalgOp accepted by `f` and pads it according to `options`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

/// Same as the overload without `opName`, but the stored filter additionally
/// requires the op name `opName` (added to `f` via addOpNameFilter).
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
503 
504 FailureOr<LinalgOp>
505 mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
506     LinalgOp linalgOp, PatternRewriter &rewriter) const {
507   if (!linalgOp.hasTensorSemantics())
508     return failure();
509   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
510     return failure();
511 
512   // Pad the operation.
513   LinalgOp paddedOp;
514   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
515       rewriter, linalgOp, options.paddingValueComputationFunction,
516       options.paddingNoFoldComputationFunction, paddedOp);
517   if (failed(newResults))
518     return failure();
519 
520   // Compute the desired hoisting depths.
521   SmallVector<int64_t> depths;
522   if (options.paddingHoistComputationFunction) {
523     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
524       depths.push_back(options.paddingHoistComputationFunction(*opOperand));
525   }
526 
527   // Hoist the padding.
528   for (const auto &en : enumerate(depths)) {
529     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
530     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
531     if (!padOp || en.value() == 0)
532       continue;
533     tensor::PadOp hoistedOp;
534     SmallVector<GenericOp> transposeOps;
535     SmallVector<int64_t> transposeVector =
536         options.paddingTransposeComputationFunction(opOperand);
537 
538     FailureOr<Value> newResult = hoistPaddingOnTensors(
539         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
540     if (failed(newResult))
541       continue;
542     rewriter.replaceOp(padOp, newResult.getValue());
543 
544     // Do not apply hoist padding to the newly introduced transpose operations.
545     for (GenericOp transposeOp : transposeOps)
546       filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
547   }
548 
549   // Replace the original operation to pad.
550   rewriter.replaceOp(linalgOp, newResults.getValue());
551   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
552 
553   return paddedOp;
554 }
555 
/// Linalg tile and fuse tensor ops pattern.
/// This overload is registered for any op type (MatchAnyOpTypeTag);
/// matchAndRewrite rejects ops that are not LinalgOps.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

/// Roots the pattern on ops named `opName` through RewritePattern. `f` is
/// stored unchanged; no op-name filter is added to it.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}
572 
573 LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
574     Operation *op, PatternRewriter &rewriter) const {
575   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
576   if (!rootOp)
577     return failure();
578   if (failed(filter.checkAndNotify(rewriter, op)))
579     return failure();
580 
581   // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
582   if (options.tileSizes.size() < rootOp.getNumLoops())
583     return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
584 
585   // Check `tileInterchange` contains no entries or as many as `tileSizes`.
586   if (!options.tileInterchange.empty() &&
587       options.tileInterchange.size() != options.tileSizes.size())
588     return rewriter.notifyMatchFailure(
589         op, "expect the number of tile sizes and interchange dims to match");
590 
591   // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
592   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
593                                      options.tileSizes.begin() +
594                                          rootOp.getNumLoops());
595   SmallVector<int64_t> rootInterchange =
596       options.tileInterchange.empty()
597           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
598           : SmallVector<int64_t>(options.tileInterchange.begin(),
599                                  options.tileInterchange.begin() +
600                                      rootOp.getNumLoops());
601 
602   // Check `rootTileSizes` contains non-zero tile sizes.
603   if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
604     return rewriter.notifyMatchFailure(
605         op, "expect at least one non-zero tile size");
606 
607   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
608   // It has to be a permutation since the tiling cannot tile the same loop
609   // dimension multiple times.
610   if (!isPermutation(rootInterchange))
611     return rewriter.notifyMatchFailure(
612         op, "expect the tile interchange permutes the root loops");
613 
614   // Tile `rootOp` and fuse its producers.
615   FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
616       rewriter, rootOp, rootTileSizes, rootInterchange);
617   if (failed(tileLoopNest))
618     return rewriter.notifyMatchFailure(
619         op, "tileConsumerAndFuseProducers failed unexpectedly");
620 
621   // Replace all uses of the tiled loop operation.
622   rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
623 
624   // Apply the filter if specified.
625   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
626     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
627   return success();
628 }
629 
/// Linalg generic interchange pattern.
/// Matches GenericOps accepted by `f` and permutes their loops according to
/// `interchangeVector` (copied into the pattern).
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
636 
637 FailureOr<GenericOp>
638 mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
639     GenericOp genericOp, PatternRewriter &rewriter) const {
640   if (failed(filter.checkAndNotify(rewriter, genericOp)))
641     return failure();
642 
643   FailureOr<GenericOp> transformedOp =
644       interchangeGenericOp(rewriter, genericOp, interchangeVector);
645   if (failed(transformedOp))
646     return failure();
647 
648   // New filter if specified.
649   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
650   return transformedOp;
651 }
652 
/// Linalg generalization pattern.
/// Matches any LinalgOp accepted by `f` and rewrites it as a GenericOp.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

/// Same as the overload without `opName`, but the stored filter additionally
/// requires the op name `opName` (added to `f` via addOpNameFilter).
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}
664 
665 FailureOr<GenericOp>
666 mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
667     LinalgOp linalgOp, PatternRewriter &rewriter) const {
668   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
669     return failure();
670   FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
671   if (failed(genericOp))
672     return failure();
673   filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
674   return genericOp;
675 }
676 
/// Creates a subview-promotion pattern registered for any op type
/// (MatchAnyOpTypeTag). Note the argument order (`f` before `options`) differs
/// from the `opName` overload.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

/// Creates a subview-promotion pattern rooted on ops named `opName` through
/// RewritePattern. `f` is stored unchanged; no op-name filter is added to it.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}
688 
689 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
690     Operation *op, PatternRewriter &rewriter) const {
691   if (failed(filter.checkAndNotify(rewriter, op)))
692     return failure();
693   if (failed(promoteSubviewsPrecondition(op, options)))
694     return failure();
695 
696   // TODO: We cannot use root update here. This pattern is creating other ops,
697   // so if the promotion fails, those need to be cleaned up, which doesnt seem
698   // to be happening here. So to fail properly, we should be cloning the op and
699   // deleting the previous op. This needs more investigation.
700   rewriter.startRootUpdate(op);
701   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
702   if (!promotedOp) {
703     rewriter.cancelRootUpdate(op);
704     return op->emitError("subview promotion failed");
705   }
706   rewriter.finalizeRootUpdate(op);
707   filter.replaceLinalgTransformationFilter(rewriter, op);
708   return success();
709 }
710 
711 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
712     MLIRContext *context, LinalgTransformationFilter f,
713     LinalgVectorizationOptions options, PatternBenefit benefit)
714     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
715       filter(std::move(f)) {}
716 
717 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
718     StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
719     LinalgTransformationFilter f, PatternBenefit benefit)
720     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
721       filter(f.addOpNameFilter(opName)) {}
722 
723 LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
724     LinalgOp linalgOp, PatternRewriter &rewriter) const {
725   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
726     return failure();
727   return vectorize(rewriter, linalgOp);
728 }
729 
730 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
731     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
732   return vectorizeCopy(rewriter, copyOp);
733 }
734 
735 LogicalResult mlir::linalg::applyStagedPatterns(
736     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
737     const FrozenRewritePatternSet &stage2Patterns,
738     function_ref<LogicalResult(Operation *)> stage3Lambda) {
739   unsigned iteration = 0;
740   (void)iteration;
741   for (const auto &patterns : stage1Patterns) {
742     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
743                       << *op);
744     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
745       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
746       return failure();
747     }
748     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
749                       << *op);
750     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
751       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
752       return failure();
753     }
754     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
755                       << *op);
756     if (stage3Lambda) {
757       if (failed(stage3Lambda(op)))
758         return failure();
759       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
760                         << *op);
761     }
762   }
763   return success();
764 }
765 
766 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
767   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
768 }
769 
770 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
771 /// initialize with pad_val) and GenericOp (to copy contents).
772 LogicalResult
773 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
774                                             PatternRewriter &rewriter) const {
775 
776   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
777   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
778 
779   // Bail on non-static shapes.
780   if (!inputShapedType.hasStaticShape())
781     return failure();
782   if (!resultShapedType.hasStaticShape())
783     return failure();
784 
785   // Only support padding with a constant for now, i.e. either:
786   //   1. A BBarg from a different block.
787   //   2. A value defined outside of the current block.
788   Block &block = padOp.region().front();
789   auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
790   Value padValue = yieldOp.value();
791   Operation *definingOp = padValue.getDefiningOp();
792   if (definingOp && definingOp->getBlock() == &block)
793     return failure();
794   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
795     return failure();
796 
797   // Create tensor with the padded shape
798   Location loc = padOp.getLoc();
799   SmallVector<Value> indices(resultShapedType.getRank(),
800                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
801   Value initTensor = rewriter.create<InitTensorOp>(
802       loc, resultShapedType.getShape(), resultShapedType.getElementType());
803 
804   // Initialize tensor with the pad value
805   Value tmpTensor =
806       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
807 
808   // Copy original contents into new tensor
809   // Uses linalg.generic, but could be done with tensor.insert_slice
810   SmallVector<AffineExpr, 4> outputExprs;
811   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
812     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
813                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
814   }
815 
816   SmallVector<AffineMap, 2> transferMaps = {
817       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
818       AffineMap::get(resultShapedType.getRank(),
819                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
820 
821   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
822       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
823       getNParallelLoopsAttrs(resultShapedType.getRank()),
824       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
825         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
826       });
827 
828   return success();
829 }
830 
831 /// Filling `dest` using FillOp constant padding value if possible.
832 /// Otherwise, generate a tensor::GenerateOp.
833 Value GeneralizePadOpPattern::createFillOrGenerateOp(
834     PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
835     const SmallVector<Value> &dynSizes) const {
836   auto padValue = padOp.getConstantPaddingValue();
837   if (padValue)
838     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
839 
840   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
841   auto generateOp = rewriter.create<tensor::GenerateOp>(
842       padOp.getLoc(), padOp.getResultType(), dynSizes);
843   // Copy region to new op.
844   BlockAndValueMapping bvm;
845   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
846   return generateOp;
847 }
848 
849 LogicalResult
850 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
851                                         PatternRewriter &rewriter) const {
852   // Given an OpFoldResult, return an index-typed value.
853   auto getIdxValue = [&](OpFoldResult ofr) {
854     if (auto val = ofr.dyn_cast<Value>())
855       return val;
856     return rewriter
857         .create<arith::ConstantIndexOp>(
858             padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
859         .getResult();
860   };
861 
862   auto resultType = padOp.getResultType();
863   // Compute size of InitTensorOp. Any combination of static/dynamic is
864   // supported.
865   SmallVector<Value> dynSizes;
866   SmallVector<int64_t> staticSizes;
867   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
868     if (resultType.isDynamicDim(dim)) {
869       auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
870                                                           padOp.source(), dim);
871       // Add low and high padding value.
872       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
873           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
874       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
875           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
876       dynSizes.push_back(plusHigh);
877     }
878     staticSizes.push_back(resultType.getDimSize(dim));
879   }
880 
881   // Init tensor and fill it with padding.
882   Value init = rewriter.create<InitTensorOp>(
883       padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
884   Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
885 
886   // Try optimize the copy of source.
887   if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
888     return success();
889 
890   // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
891   // for copying the PadOp source.
892   auto sourceType = padOp.getSourceType();
893   // Compute size of source of tensor::PadOp.
894   SmallVector<OpFoldResult> srcSizes;
895   for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
896     if (sourceType.isDynamicDim(dim)) {
897       srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
898           padOp.getLoc(), padOp.source(), dim));
899     } else {
900       srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
901     }
902   }
903   // Strides of InsertSliceOp are all 1.
904   SmallVector<OpFoldResult> strides(sourceType.getRank(),
905                                     rewriter.getIndexAttr(1));
906   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
907       padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
908 
909   return success();
910 }
911 
912 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
913     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
914   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
915   if (!padOp)
916     return failure();
917   // Only unit stride supported.
918   if (!sliceOp.hasUnitStride())
919     return failure();
920 
921   TilingInterface tilingInterface =
922       dyn_cast<TilingInterface>(padOp.getOperation());
923   Operation *tiledPadOp =
924       tilingInterface
925           .getTiledImplementation(
926               rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
927               sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
928           .front();
929   // All shapes are static and the data source is actually used. Rewrite into
930   // pad_tensor(subtensor(x)).
931   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
932   return success();
933 }
934 
935 namespace {
936 // The following are patterns for downscaling convolution ops with size-1
937 // window dimensions.
938 //
939 // Note that we'd eventually want to write such transformations in a generic
940 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
941 // and then turning back to named ops. But for now it's fine to have a few
942 // patterns matching special ops to get started.
943 
944 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
945 /// convolution ops.
946 struct DownscaleSizeOneWindowed2DConvolution final
947     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
948   DownscaleSizeOneWindowed2DConvolution(
949       MLIRContext *context,
950       LinalgTransformationFilter f = LinalgTransformationFilter(),
951       PatternBenefit benefit = 1)
952       : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
953         filter(std::move(f)) {}
954 
955   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
956                                 PatternRewriter &rewriter) const override {
957     if (failed(filter.checkAndNotify(rewriter, convOp)))
958       return failure();
959     if (convOp.hasBufferSemantics())
960       return failure(); // To be implemented
961 
962     Value input = convOp.inputs().front();
963     Value kernel = convOp.inputs().back();
964     Value output = convOp.outputs().front();
965 
966     auto inputType = input.getType().dyn_cast<RankedTensorType>();
967     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
968     auto outputType = output.getType().dyn_cast<RankedTensorType>();
969 
970     auto kernelShape = kernelType.getShape();
971     auto outputShape = outputType.getShape();
972 
973     // Only handle the case where at least one of the window dimensions is
974     // of size 1. Other cases can rely on tiling to reduce to such cases.
975     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
976     int64_t ohSize = outputShape[1], owSize = outputShape[2];
977     bool removeH = (khSize == 1 && ohSize == 1);
978     bool removeW = (kwSize == 1 && owSize == 1);
979     if (!removeH && !removeW)
980       return failure();
981 
982     // Get new shapes and types for all operands by removing the size-1
983     // dimension.
984     using RTTBuilder = RankedTensorType::Builder;
985     RankedTensorType newInputType =
986         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
987     RankedTensorType newKernelType =
988         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
989     RankedTensorType newOutputType =
990         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
991 
992     // Rank-reduce operands.
993     Location loc = convOp.getLoc();
994     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
995         rewriter, loc, input, newInputType);
996     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
997         rewriter, loc, kernel, newKernelType);
998     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
999         rewriter, loc, output, newOutputType);
1000 
1001     // Rank-reduce strides and dilations too.
1002     // TODO: dropDim 1-liner helper.
1003     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1004     strides.erase(strides.begin() + (removeH ? 0 : 1));
1005     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1006 
1007     auto dilations =
1008         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1009     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1010     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1011 
1012     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1013         loc, newOutputType, ValueRange{newInput, newKernel},
1014         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1015 
1016     // Insert back.
1017     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1018         rewriter, loc, conv1DOp.getResult(0), output);
1019     rewriter.replaceOp(convOp, inserted);
1020 
1021     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1022     return success();
1023   };
1024 
1025 private:
1026   /// LinalgTransformMarker handles special attribute manipulations.
1027   LinalgTransformationFilter filter;
1028 };
1029 
1030 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1031 /// dimensions into 1-D depthwise convolution ops.
1032 struct DownscaleDepthwiseConv2DNhwcHwcOp final
1033     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
1034   DownscaleDepthwiseConv2DNhwcHwcOp(
1035       MLIRContext *context,
1036       LinalgTransformationFilter f = LinalgTransformationFilter(),
1037       PatternBenefit benefit = 1)
1038       : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
1039         filter(std::move(f)) {}
1040 
1041   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
1042                                 PatternRewriter &rewriter) const override {
1043     if (failed(filter.checkAndNotify(rewriter, convOp)))
1044       return failure();
1045     if (convOp.hasBufferSemantics())
1046       return failure(); // To be implemented
1047 
1048     Value input = convOp.inputs().front();
1049     Value kernel = convOp.inputs().back();
1050     Value output = convOp.outputs().front();
1051 
1052     auto inputType = input.getType().dyn_cast<RankedTensorType>();
1053     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1054     auto outputType = output.getType().dyn_cast<RankedTensorType>();
1055 
1056     auto kernelShape = kernelType.getShape();
1057     auto outputShape = outputType.getShape();
1058 
1059     // Only handle the case where at least one of the window dimensions is
1060     // of size 1. Other cases can rely on tiling to reduce to such cases.
1061     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1062     int64_t ohSize = outputShape[1], owSize = outputShape[2];
1063     bool removeH = (khSize == 1 && ohSize == 1);
1064     bool removeW = (kwSize == 1 && owSize == 1);
1065     if (!removeH && !removeW)
1066       return failure();
1067 
1068     // Get new shapes and types for all operands by removing the size-1
1069     // dimension.
1070     using RTTBuilder = RankedTensorType::Builder;
1071     RankedTensorType newInputType =
1072         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1073     RankedTensorType newKernelType =
1074         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1075     RankedTensorType newOutputType =
1076         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1077 
1078     // Rank-reduce operands.
1079     Location loc = convOp.getLoc();
1080     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1081         rewriter, loc, input, newInputType);
1082     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1083         rewriter, loc, kernel, newKernelType);
1084     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1085         rewriter, loc, output, newOutputType);
1086 
1087     // Rank-reduce strides and dilations too.
1088     // TODO: dropDim 1-liner helper.
1089     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1090     strides.erase(strides.begin() + (removeH ? 0 : 1));
1091     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1092 
1093     auto dilations =
1094         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1095     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1096     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1097 
1098     auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1099         loc, newOutputType, ValueRange{newInput, newKernel},
1100         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1101 
1102     // Insert back.
1103     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1104         rewriter, loc, conv1DOp.getResult(0), output);
1105     rewriter.replaceOp(convOp, inserted);
1106 
1107     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1108     return success();
1109   };
1110 
1111 private:
1112   /// LinalgTransformMarker handles special attribute manipulations.
1113   LinalgTransformationFilter filter;
1114 };
1115 
1116 } // namespace
1117 
1118 void linalg::populateDecomposeConvolutionPatterns(
1119     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
1120     PatternBenefit benefit) {
1121   patterns.add<DownscaleSizeOneWindowed2DConvolution,
1122                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1123                                                   benefit);
1124 }
1125