1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SCF/Transforms.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Utils/StaticValueUtils.h"
23 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
24 #include "mlir/Dialect/Vector/VectorOps.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "llvm/ADT/ScopeExit.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <type_traits>
35 #include <utility>
36 
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// In this file, LinalgTransformationFilter reads the attribute stored under
// this name in `checkAndNotify` and sets or removes it in
// `replaceLinalgTransformationFilter`.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
50 
51 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
52     ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
53     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
54       replacement(replacement), matchByDefault(false) {}
55 
56 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
57     const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
58     Optional<StringAttr> replacement)
59     : filters(),
60       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
61       replacement(replacement), matchByDefault(false) {
62   if (f)
63     filters.push_back(f);
64 }
65 
66 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
67     PatternRewriter &rewriter, Operation *op) const {
68   if (llvm::any_of(filters,
69                    [&](const FilterFunction &f) { return failed(f(op)); }))
70     return failure();
71 
72   auto attr = op->template getAttrOfType<StringAttr>(
73       LinalgTransforms::kLinalgTransformMarker);
74 
75   if (!attr) {
76     // 1. Has no filter case and matchDisjunction is empty.
77     if (matchDisjunction.empty() || matchByDefault)
78       return success();
79 
80     // 2. Has no filter but was expecting a filter.
81     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
82       diag << " does not have any filter from list: ";
83       interleaveComma(matchDisjunction, diag);
84     });
85   }
86 
87   // 4. Match explicit filter.
88   for (auto filter : matchDisjunction)
89     if (attr.getValue() == filter)
90       return success();
91 
92   // 5. Fail to match.
93   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
94     diag << " does not have any filter from list: ";
95     interleaveComma(matchDisjunction, diag);
96   });
97 }
98 
99 void mlir::linalg::LinalgTransformationFilter::
100     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
101                                       Operation *op) const {
102   if (replacement.hasValue())
103     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
104                 replacement.getValue());
105   else
106     op->removeAttr(
107         rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
108 }
109 
110 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
111     Operation *op) const {
112   if (!replacement)
113     return false;
114   auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
115                   .dyn_cast<StringAttr>();
116   return attr && attr == replacement.getValue();
117 }
118 
119 LinalgTilingOptions &
120 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
121   assert(!tileSizeComputationFunction && "tile sizes already set");
122   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
123   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
124     OpBuilder::InsertionGuard guard(b);
125     b.setInsertionPointToStart(
126         &op->getParentOfType<FuncOp>().getBody().front());
127     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
128       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
129       return v;
130     }));
131   };
132   return *this;
133 }
134 
135 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
136   assert(!tileSizeComputationFunction && "tile sizes already set");
137   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
138     SmallVector<Value, 4> tileSizes;
139     auto linalgOp = dyn_cast<LinalgOp>(op);
140     if (!linalgOp)
141       return tileSizes;
142     Location loc = linalgOp.getLoc();
143     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
144     AffineMap map = linalgOp.getShapesToLoopsMap();
145     if (!map)
146       return tileSizes;
147     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
148     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
149     // size 0).
150     for (Value shapeSize : shapeSizes)
151       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
152                               ? b.create<arith::ConstantIndexOp>(loc, 0)
153                               : b.create<arith::ConstantIndexOp>(loc, 1));
154     return tileSizes;
155   };
156   return *this;
157 }
158 
159 /// Helper function that tries to pad `opOperand`. Exit early for scalar
160 /// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
161 /// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
162 /// has a static shape. Set `result` to the result of the created tensor::PadOp
163 /// or and return success if the operand either has been padded to a static
164 /// shape or already had a static shape and failure otherwise.
165 static LogicalResult padOperandToSmallestStaticBoundingBox(
166     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
167     const PaddingValueComputationFunction &paddingFunc,
168     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
169   // Get the shape of the operand and check if it has a dynamic shape. Only
170   // return failure if the operand is not a scalar and has a dynamic shape.
171   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
172   bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
173 
174   // Cannot pad scalar operands.
175   if (shape.empty())
176     return success();
177 
178   // Cannot pad if the padding value is unknown.
179   FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
180   if (failed(paddingValue))
181     return failure(hasDynamicShape);
182 
183   // Cannot construct a static bounding box if the operand is not defined by an
184   // ExtractSliceOp.
185   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
186   if (!sliceOp)
187     return failure(hasDynamicShape);
188 
189   // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
190   llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
191 
192   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
193   SmallVector<int64_t> staticSizes;
194   staticSizes.reserve(shape.size());
195   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
196   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
197     // Skip dropped dimensions.
198     if (droppedDims.contains(en.index()))
199       continue;
200     // If the size is an attribute add it directly to `staticSizes`.
201     if (en.value().is<Attribute>()) {
202       staticSizes.push_back(
203           en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
204       continue;
205     }
206     // Otherwise, try to compute a constant upper bound for the size value.
207     FailureOr<int64_t> upperBound =
208         getConstantUpperBoundForIndex(en.value().get<Value>());
209     if (failed(upperBound)) {
210       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
211       return failure();
212     }
213     staticSizes.push_back(upperBound.getValue());
214   }
215   assert(staticSizes.size() == shape.size() &&
216          "expect the dynamic and static ranks to match");
217 
218   // Pad the operand to the bounding box defined by `staticSizes`.
219   auto staticTensorType = RankedTensorType::get(
220       staticSizes, getElementTypeOrSelf(opOperand->get()));
221   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
222   result =
223       makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
224                             opOperand->get(), paddingValue.getValue(), nofold);
225   return success();
226 }
227 
228 FailureOr<SmallVector<Value>>
229 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
230                           const PaddingValueComputationFunction &paddingFunc,
231                           const PaddingNoFoldComputationFunction &nofoldFunc,
232                           LinalgOp &paddedOp) {
233   Location loc = opToPad->getLoc();
234 
235   // TODO: there are cases where we may still want to pad to larger sizes.
236   assert(opToPad.hasTensorSemantics() &&
237          "expected operation to have tensor semantics");
238 
239   OpBuilder::InsertionGuard g(b);
240   // Set IP after op because we also take the dims of the original output.
241   b.setInsertionPointAfter(opToPad);
242   // Make a copy of the shaped operands and update it.
243   SmallVector<Value> newOperands;
244   newOperands.reserve(opToPad.getNumInputsAndOutputs());
245   for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
246     Value paddedOperand;
247     // If padding was requested but the shape cannot be bounded statically then
248     // the pattern fails to apply.
249     if (failed(padOperandToSmallestStaticBoundingBox(
250             b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
251       return failure();
252     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
253   }
254 
255   SmallVector<SmallVector<Value>> reifiedResultShapes;
256   if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
257                  .reifyResultShapes(b, reifiedResultShapes)))
258     return failure();
259   assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
260          "expected same number of results");
261 
262   // Clone `opToPad` to operate on the statically padded shapes.
263   auto resultTensorTypes =
264       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
265   paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
266 
267   // Recover the slice out of the new static results. This keeps the original
268   // linalg op around because it uses the dims of the original results.
269   SmallVector<Value> paddedSubviewResults;
270   paddedSubviewResults.reserve(opToPad->getNumResults());
271   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
272     Value paddedResult = en.value();
273     int64_t resultNumber = en.index();
274     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
275     SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
276     SmallVector<OpFoldResult> sizes;
277     for (Value v : reifiedResultShapes[resultNumber])
278       sizes.push_back(getAsOpFoldResult(v));
279     SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
280     paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
281         loc, paddedResult, offsets, sizes, strides));
282   }
283   return paddedSubviewResults;
284 }
285 
286 /// Try to peel a loop `op` and return the new result.
287 // TODO: Add support for scf.parallel and affine.for loops.
288 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
289   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
290       .Case<scf::ForOp>([&](scf::ForOp forOp) {
291         scf::ForOp partialIteration;
292         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
293                                                       partialIteration)))
294           return partialIteration->getResults();
295         assert(!partialIteration && "expected that loop was not peeled");
296         return forOp->getResults();
297       })
298       .Default([&](Operation *op) { return op->getResults(); });
299 }
300 
301 /// Try to peel a TiledLoopOp and return the new result.
302 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
303                                       TiledLoopOp tiledLoop, int64_t idx) {
304   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
305          "requested peeling of non-existing loop");
306   TiledLoopOp result;
307   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
308     return result->getResults();
309   assert(!result && "expected that loop was not peeled");
310   return tiledLoop->getResults();
311 }
312 
313 /// Peel loops after tiling.
314 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
315                                      ArrayRef<int64_t> peeledLoops,
316                                      LinalgTilingLoopType loopType) {
317   for (int64_t loop : peeledLoops) {
318     assert(loop < static_cast<int64_t>(res.loops.size()) &&
319            "requested peeling of non-existing loop");
320     SmallVector<Value, 4> loopResults;
321     Operation *loopOp = res.loops[loop];
322     if (loopType == LinalgTilingLoopType::TiledLoops) {
323       assert(llvm::all_of(
324                  res.loops,
325                  [&](Operation *op) { return op == res.loops.front(); }) &&
326              "expected that all loop ops are the same TiledLoopOp");
327       auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
328       assert(tiledLoopOp && "expected TiledLoopOp");
329       loopResults = peelLoop(rewriter, tiledLoopOp, loop);
330     } else {
331       loopResults = peelLoop(rewriter, loopOp);
332     }
333 
334     // The result of the loop nest may change with peeling.
335     if (res.tensorResults.size() == loopOp->getNumResults() &&
336         std::equal(res.tensorResults.begin(), res.tensorResults.end(),
337                    loopOp->getResults().begin()))
338       res.tensorResults = loopResults;
339   }
340 }
341 
342 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
343   if (tiledOp.loops.empty())
344     return tiledOp.op.getOperation()->getResults();
345   return tiledOp.loops.front()->getResults();
346 }
347 
348 static ValueRange
349 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
350   if (tiledAndFusedOp.fusedLoops.empty())
351     return tiledAndFusedOp.op.getOperation()->getResults();
352   return tiledAndFusedOp.fusedLoops.front()->getResults();
353 }
354 
/// Construct a tile-and-fuse pattern rooted at operations named `opName`.
/// The pattern keeps a reference to `dependenceGraph` and takes ownership of
/// the tiling options, the fusion options and the three filters: `f` is the
/// filter checked on the root op, `fusedOpMarker` is applied to the fused
/// producers and `originalOpMarker` to the original producers and root op.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}
366 
367 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
368     Operation *op, PatternRewriter &rewriter) const {
369   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
370   // TODO: remove hasIndexSemantics check once index ops are supported.
371   if (!linalgOp || linalgOp.hasIndexSemantics())
372     return failure();
373   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
374     return failure();
375 
376   DenseSet<Operation *> producers;
377   producers.insert(linalgOp);
378   for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
379     Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
380     // When looking at dependences into, indexingOp is always OpOperand. We
381     // could assert, but continue if this is not the case.
382     if (!operandNumber)
383       continue;
384     if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
385       continue;
386     if (isa<LinalgOp>(dependence.getDependentOp()))
387       producers.insert(dependence.getDependentOp());
388   }
389 
390   SmallVector<LinalgOp, 1> fusionOps;
391   for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
392        ++it) {
393     auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
394     if (producerLinalgOp && producers.count(producerLinalgOp))
395       fusionOps.push_back(producerLinalgOp);
396   }
397   fusionOps.push_back(linalgOp);
398 
399   SmallVector<Value, 4> tileSizes =
400       tilingOptions.tileSizeComputationFunction(rewriter, op);
401   LinalgTilingOptions instanceTilingOptions = tilingOptions;
402   instanceTilingOptions.setTileSizes(tileSizes);
403   Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
404       rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
405   if (!tiledAndFusedOps)
406     return failure();
407 
408   // Tile the unfused loops;
409   SmallVector<Value, 4> unfusedLoopTileSizes;
410   Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
411   for (const auto &tileSize : enumerate(tileSizes)) {
412     if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
413       unfusedLoopTileSizes.push_back(zero);
414     else
415       unfusedLoopTileSizes.push_back(tileSize.value());
416   }
417   // Tile the loop only if there is a non-zero tile size.
418   if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
419     unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
420   if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
421         if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
422           return cst.value() != 0;
423         return true;
424       })) {
425     LinalgTilingOptions unfusedTilingOptions = tilingOptions;
426     unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
427     FailureOr<TiledLinalgOp> unfusedTiledOp =
428         tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
429     if (failed(unfusedTiledOp))
430       return failure();
431     rewriter.replaceOp(tiledAndFusedOps->op,
432                        getTiledOpResult(unfusedTiledOp.getValue()));
433     tiledAndFusedOps->op = unfusedTiledOp->op;
434   }
435   op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
436 
437   filter.replaceLinalgTransformationFilter(rewriter,
438                                            tiledAndFusedOps->op.getOperation());
439   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
440     fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
441                                                     fusedOp.getOperation());
442   }
443   for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
444     originalOpMarker.replaceLinalgTransformationFilter(
445         rewriter, origProducerOp.getOperation());
446   }
447   rewriter.updateRootInPlace(op, [&]() {
448     originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
449   });
450   return success();
451 }
452 
/// Linalg tiling pattern.
/// Construct a tiling pattern on the LinalgOp interface that stores the
/// filter `f` and the tiling `options`.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}
459 
/// Construct a tiling pattern on the LinalgOp interface whose filter is `f`
/// after `opName` has been registered on it through `addOpNameFilter`.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
465 
466 FailureOr<TiledLinalgOp>
467 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
468     LinalgOp op, PatternRewriter &rewriter) const {
469   if (failed(filter.checkAndNotify(rewriter, op)))
470     return failure();
471 
472   FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
473   if (failed(res))
474     return failure();
475 
476   // Clear filter to stop recursive pattern application.
477   // This must be done here to properly propagate to peeling branches.
478   filter.replaceLinalgTransformationFilter(rewriter, res->op);
479 
480   // Peel the loops of the TiledLinalgOp.
481   peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
482 
483   if (res->tensorResults.empty())
484     rewriter.eraseOp(op);
485   else
486     rewriter.replaceOp(op, res->tensorResults);
487 
488   return res;
489 }
490 
/// Linalg padding pattern.
/// Construct a padding pattern on the LinalgOp interface that stores the
/// filter `f` and the padding `options`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}
497 
/// Construct a padding pattern on the LinalgOp interface whose filter is `f`
/// after `opName` has been registered on it through `addOpNameFilter`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
503 
504 FailureOr<LinalgOp>
505 mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
506     LinalgOp linalgOp, PatternRewriter &rewriter) const {
507   if (!linalgOp.hasTensorSemantics())
508     return failure();
509   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
510     return failure();
511 
512   // Pad the operation.
513   LinalgOp paddedOp;
514   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
515       rewriter, linalgOp, options.paddingValueComputationFunction,
516       options.paddingNoFoldComputationFunction, paddedOp);
517   if (failed(newResults))
518     return failure();
519 
520   // Compute the desired hoisting depths.
521   SmallVector<int64_t> depths;
522   if (options.paddingHoistComputationFunction) {
523     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
524       depths.push_back(options.paddingHoistComputationFunction(*opOperand));
525   }
526 
527   // Hoist the padding.
528   for (const auto &en : enumerate(depths)) {
529     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
530     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
531     if (!padOp || en.value() == 0)
532       continue;
533     tensor::PadOp hoistedOp;
534     SmallVector<GenericOp> transposeOps;
535     SmallVector<int64_t> transposeVector =
536         options.paddingTransposeComputationFunction(opOperand);
537 
538     FailureOr<Value> newResult = hoistPaddingOnTensors(
539         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
540     if (failed(newResult))
541       continue;
542     rewriter.replaceOp(padOp, newResult.getValue());
543 
544     // Do not apply hoist padding to the newly introduced transpose operations.
545     for (GenericOp transposeOp : transposeOps)
546       filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
547   }
548 
549   // Replace the original operation to pad.
550   rewriter.replaceOp(linalgOp, newResults.getValue());
551   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
552 
553   return paddedOp;
554 }
555 
/// Linalg tile and fuse tensor ops pattern.
/// Construct the pattern with `MatchAnyOpTypeTag`, i.e. without a fixed root
/// operation name, storing the filter `f` and the tiling and fusion `options`.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}
564 
/// Construct the pattern rooted at operations named `opName`, storing the
/// filter `f` and the tiling and fusion `options`.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}
572 
573 LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
574     Operation *op, PatternRewriter &rewriter) const {
575   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
576   if (!rootOp)
577     return failure();
578   if (failed(filter.checkAndNotify(rewriter, op)))
579     return failure();
580 
581   // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
582   if (options.tileSizes.size() < rootOp.getNumLoops())
583     return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
584 
585   // Check `tileInterchange` contains no entries or as many as `tileSizes`.
586   if (!options.tileInterchange.empty() &&
587       options.tileInterchange.size() != options.tileSizes.size())
588     return rewriter.notifyMatchFailure(
589         op, "expect the number of tile sizes and interchange dims to match");
590 
591   // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
592   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
593                                      options.tileSizes.begin() +
594                                          rootOp.getNumLoops());
595   SmallVector<int64_t> rootInterchange =
596       options.tileInterchange.empty()
597           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
598           : SmallVector<int64_t>(options.tileInterchange.begin(),
599                                  options.tileInterchange.begin() +
600                                      rootOp.getNumLoops());
601 
602   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
603   // It has to be a permutation since the tiling cannot tile the same loop
604   // dimension multiple times.
605   if (!isPermutation(rootInterchange))
606     return rewriter.notifyMatchFailure(
607         op, "expect the tile interchange permutes the root loops");
608 
609   // Tile `rootOp` and fuse its producers.
610   FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
611       rewriter, rootOp, rootTileSizes, rootInterchange);
612   if (failed(tileLoopNest))
613     return rewriter.notifyMatchFailure(
614         op, "tileConsumerAndFuseProducers failed unexpectedly");
615 
616   // Replace all uses of the tiled loop operation.
617   rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
618 
619   // Apply the filter if specified.
620   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
621     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
622   return failure();
623 }
624 
/// Linalg generic interchange pattern.
/// Construct the pattern, storing the filter `f` and a copy of
/// `interchangeVector`.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
631 
632 FailureOr<GenericOp>
633 mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
634     GenericOp genericOp, PatternRewriter &rewriter) const {
635   if (failed(filter.checkAndNotify(rewriter, genericOp)))
636     return failure();
637 
638   FailureOr<GenericOp> transformedOp =
639       interchangeGenericOp(rewriter, genericOp, interchangeVector);
640   if (failed(transformedOp))
641     return failure();
642 
643   // New filter if specified.
644   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
645   return transformedOp;
646 }
647 
/// Linalg generalization pattern.
/// Construct a generalization pattern on the LinalgOp interface that stores
/// the filter `f`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}
653 
/// Construct a generalization pattern on the LinalgOp interface whose filter
/// is `f` after `opName` has been registered on it through `addOpNameFilter`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}
659 
660 FailureOr<GenericOp>
661 mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
662     LinalgOp linalgOp, PatternRewriter &rewriter) const {
663   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
664     return failure();
665   FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
666   if (failed(genericOp))
667     return failure();
668   filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
669   return genericOp;
670 }
671 
/// Construct a promotion pattern with `MatchAnyOpTypeTag`, i.e. without a
/// fixed root operation name, storing the filter `f` and the promotion
/// `options`.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}
677 
/// Construct a promotion pattern rooted at operations named `opName`, storing
/// the filter `f` and the promotion `options`. Note that `options` precedes
/// `f` in this overload, unlike in the context-only overload.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}
683 
684 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
685     Operation *op, PatternRewriter &rewriter) const {
686   if (failed(filter.checkAndNotify(rewriter, op)))
687     return failure();
688   if (failed(promoteSubviewsPrecondition(op, options)))
689     return failure();
690 
691   // TODO: We cannot use root update here. This pattern is creating other ops,
692   // so if the promotion fails, those need to be cleaned up, which doesnt seem
693   // to be happening here. So to fail properly, we should be cloning the op and
694   // deleting the previous op. This needs more investigation.
695   rewriter.startRootUpdate(op);
696   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
697   if (!promotedOp) {
698     rewriter.cancelRootUpdate(op);
699     return op->emitError("subview promotion failed");
700   }
701   rewriter.finalizeRootUpdate(op);
702   filter.replaceLinalgTransformationFilter(rewriter, op);
703   return success();
704 }
705 
/// Constructs a vectorization pattern applying to any LinalgOp accepted by the
/// filter `f`, which is moved into the `filter` member.
/// Note: `options` is accepted but not stored or otherwise used by this
/// constructor.
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}
711 
/// Constructs a vectorization pattern additionally restricted by op name:
/// `f.addOpNameFilter(opName)` is invoked on the by-value filter and the result
/// is stored as this pattern's filter.
/// Note: `options` is accepted but not stored or otherwise used by this
/// constructor.
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}
717 
718 LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
719     LinalgOp linalgOp, PatternRewriter &rewriter) const {
720   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
721     return failure();
722   return vectorize(rewriter, linalgOp);
723 }
724 
725 LogicalResult mlir::linalg::applyStagedPatterns(
726     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
727     const FrozenRewritePatternSet &stage2Patterns,
728     function_ref<LogicalResult(Operation *)> stage3Lambda) {
729   unsigned iteration = 0;
730   (void)iteration;
731   for (const auto &patterns : stage1Patterns) {
732     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
733                       << *op);
734     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
735       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
736       return failure();
737     }
738     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
739                       << *op);
740     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
741       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
742       return failure();
743     }
744     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
745                       << *op);
746     if (stage3Lambda) {
747       if (failed(stage3Lambda(op)))
748         return failure();
749       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
750                         << *op);
751     }
752   }
753   return success();
754 }
755 
756 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
757   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
758 }
759 
760 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
761 /// initialize with pad_val) and GenericOp (to copy contents).
762 LogicalResult
763 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
764                                             PatternRewriter &rewriter) const {
765 
766   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
767   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
768 
769   // Bail on non-static shapes.
770   if (!inputShapedType.hasStaticShape())
771     return failure();
772   if (!resultShapedType.hasStaticShape())
773     return failure();
774 
775   // Only support padding with a constant for now, i.e. either:
776   //   1. A BBarg from a different block.
777   //   2. A value defined outside of the current block.
778   Block &block = padOp.region().front();
779   auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
780   Value padValue = yieldOp.value();
781   Operation *definingOp = padValue.getDefiningOp();
782   if (definingOp && definingOp->getBlock() == &block)
783     return failure();
784   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
785     return failure();
786 
787   // Create tensor with the padded shape
788   Location loc = padOp.getLoc();
789   SmallVector<Value> indices(resultShapedType.getRank(),
790                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
791   Value initTensor = rewriter.create<InitTensorOp>(
792       loc, resultShapedType.getShape(), resultShapedType.getElementType());
793 
794   // Initialize tensor with the pad value
795   Value tmpTensor =
796       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
797 
798   // Copy original contents into new tensor
799   // Uses linalg.generic, but could be done with tensor.insert_slice
800   SmallVector<AffineExpr, 4> outputExprs;
801   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
802     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
803                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
804   }
805 
806   SmallVector<AffineMap, 2> transferMaps = {
807       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
808       AffineMap::get(resultShapedType.getRank(),
809                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
810 
811   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
812       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
813       getNParallelLoopsAttrs(resultShapedType.getRank()),
814       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
815         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
816       });
817 
818   return success();
819 }
820 
821 /// Filling `dest` using FillOp constant padding value if possible.
822 /// Otherwise, generate a tensor::GenerateOp.
823 Value GeneralizePadOpPattern::createFillOrGenerateOp(
824     PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
825     const SmallVector<Value> &dynSizes) const {
826   auto padValue = padOp.getConstantPaddingValue();
827   if (padValue)
828     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
829 
830   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
831   auto generateOp = rewriter.create<tensor::GenerateOp>(
832       padOp.getLoc(), padOp.getResultType(), dynSizes);
833   // Copy region to new op.
834   BlockAndValueMapping bvm;
835   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
836   return generateOp;
837 }
838 
839 LogicalResult
840 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
841                                         PatternRewriter &rewriter) const {
842   // Given an OpFoldResult, return an index-typed value.
843   auto getIdxValue = [&](OpFoldResult ofr) {
844     if (auto val = ofr.dyn_cast<Value>())
845       return val;
846     return rewriter
847         .create<arith::ConstantIndexOp>(
848             padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
849         .getResult();
850   };
851 
852   auto resultType = padOp.getResultType();
853   // Compute size of InitTensorOp. Any combination of static/dynamic is
854   // supported.
855   SmallVector<Value> dynSizes;
856   SmallVector<int64_t> staticSizes;
857   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
858     if (resultType.isDynamicDim(dim)) {
859       auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
860                                                           padOp.source(), dim);
861       // Add low and high padding value.
862       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
863           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
864       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
865           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
866       dynSizes.push_back(plusHigh);
867     }
868     staticSizes.push_back(resultType.getDimSize(dim));
869   }
870 
871   // Init tensor and fill it with padding.
872   Value init = rewriter.create<InitTensorOp>(
873       padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
874   Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
875 
876   // Try optimize the copy of source.
877   if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
878     return success();
879 
880   // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
881   // for copying the PadOp source.
882   auto sourceType = padOp.getSourceType();
883   // Compute size of source of tensor::PadOp.
884   SmallVector<OpFoldResult> srcSizes;
885   for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
886     if (sourceType.isDynamicDim(dim)) {
887       srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
888           padOp.getLoc(), padOp.source(), dim));
889     } else {
890       srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
891     }
892   }
893   // Strides of InsertSliceOp are all 1.
894   SmallVector<OpFoldResult> strides(sourceType.getRank(),
895                                     rewriter.getIndexAttr(1));
896   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
897       padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
898 
899   return success();
900 }
901 
902 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
903     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
904   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
905   if (!padOp)
906     return failure();
907   // Only unit stride supported.
908   if (!sliceOp.hasUnitStride())
909     return failure();
910 
911   TilingInterface tilingInterface =
912       dyn_cast<TilingInterface>(padOp.getOperation());
913   Operation *tiledPadOp =
914       tilingInterface
915           .getTiledImplementation(
916               rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
917               sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
918           .front();
919   // All shapes are static and the data source is actually used. Rewrite into
920   // pad_tensor(subtensor(x)).
921   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
922   return success();
923 }
924 
925 namespace {
926 // The following are patterns for downscaling convolution ops with size-1
927 // window dimensions.
928 //
929 // Note that we'd eventually want to write such transformations in a generic
930 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
931 // and then turning back to named ops. But for now it's fine to have a few
932 // patterns matching special ops to get started.
933 
934 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
935 /// convolution ops.
936 struct DownscaleSizeOneWindowed2DConvolution final
937     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
938   DownscaleSizeOneWindowed2DConvolution(
939       MLIRContext *context,
940       LinalgTransformationFilter f = LinalgTransformationFilter(),
941       PatternBenefit benefit = 1)
942       : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
943         filter(std::move(f)) {}
944 
945   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
946                                 PatternRewriter &rewriter) const override {
947     if (failed(filter.checkAndNotify(rewriter, convOp)))
948       return failure();
949     if (convOp.hasBufferSemantics())
950       return failure(); // To be implemented
951 
952     Value input = convOp.inputs().front();
953     Value kernel = convOp.inputs().back();
954     Value output = convOp.outputs().front();
955 
956     auto inputType = input.getType().dyn_cast<RankedTensorType>();
957     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
958     auto outputType = output.getType().dyn_cast<RankedTensorType>();
959 
960     auto kernelShape = kernelType.getShape();
961     auto outputShape = outputType.getShape();
962 
963     // Only handle the case where at least one of the window dimensions is
964     // of size 1. Other cases can rely on tiling to reduce to such cases.
965     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
966     int64_t ohSize = outputShape[1], owSize = outputShape[2];
967     bool removeH = (khSize == 1 && ohSize == 1);
968     bool removeW = (kwSize == 1 && owSize == 1);
969     if (!removeH && !removeW)
970       return failure();
971 
972     // Get new shapes and types for all operands by removing the size-1
973     // dimension.
974     using RTTBuilder = RankedTensorType::Builder;
975     RankedTensorType newInputType =
976         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
977     RankedTensorType newKernelType =
978         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
979     RankedTensorType newOutputType =
980         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
981 
982     // Rank-reduce operands.
983     Location loc = convOp.getLoc();
984     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
985         rewriter, loc, input, newInputType);
986     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
987         rewriter, loc, kernel, newKernelType);
988     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
989         rewriter, loc, output, newOutputType);
990 
991     // Rank-reduce strides and dilations too.
992     // TODO: dropDim 1-liner helper.
993     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
994     strides.erase(strides.begin() + (removeH ? 0 : 1));
995     auto stridesAttr = rewriter.getI64VectorAttr(strides);
996 
997     auto dilations =
998         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
999     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1000     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1001 
1002     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1003         loc, newOutputType, ValueRange{newInput, newKernel},
1004         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1005 
1006     // Insert back.
1007     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1008         rewriter, loc, conv1DOp.getResult(0), output);
1009     rewriter.replaceOp(convOp, inserted);
1010 
1011     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1012     return success();
1013   };
1014 
1015 private:
1016   /// LinalgTransformMarker handles special attribute manipulations.
1017   LinalgTransformationFilter filter;
1018 };
1019 
1020 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1021 /// dimensions into 1-D depthwise convolution ops.
1022 struct DownscaleDepthwiseConv2DNhwcHwcOp final
1023     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
1024   DownscaleDepthwiseConv2DNhwcHwcOp(
1025       MLIRContext *context,
1026       LinalgTransformationFilter f = LinalgTransformationFilter(),
1027       PatternBenefit benefit = 1)
1028       : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
1029         filter(std::move(f)) {}
1030 
1031   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
1032                                 PatternRewriter &rewriter) const override {
1033     if (failed(filter.checkAndNotify(rewriter, convOp)))
1034       return failure();
1035     if (convOp.hasBufferSemantics())
1036       return failure(); // To be implemented
1037 
1038     Value input = convOp.inputs().front();
1039     Value kernel = convOp.inputs().back();
1040     Value output = convOp.outputs().front();
1041 
1042     auto inputType = input.getType().dyn_cast<RankedTensorType>();
1043     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1044     auto outputType = output.getType().dyn_cast<RankedTensorType>();
1045 
1046     auto kernelShape = kernelType.getShape();
1047     auto outputShape = outputType.getShape();
1048 
1049     // Only handle the case where at least one of the window dimensions is
1050     // of size 1. Other cases can rely on tiling to reduce to such cases.
1051     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1052     int64_t ohSize = outputShape[1], owSize = outputShape[2];
1053     bool removeH = (khSize == 1 && ohSize == 1);
1054     bool removeW = (kwSize == 1 && owSize == 1);
1055     if (!removeH && !removeW)
1056       return failure();
1057 
1058     // Get new shapes and types for all operands by removing the size-1
1059     // dimension.
1060     using RTTBuilder = RankedTensorType::Builder;
1061     RankedTensorType newInputType =
1062         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1063     RankedTensorType newKernelType =
1064         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1065     RankedTensorType newOutputType =
1066         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1067 
1068     // Rank-reduce operands.
1069     Location loc = convOp.getLoc();
1070     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1071         rewriter, loc, input, newInputType);
1072     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1073         rewriter, loc, kernel, newKernelType);
1074     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1075         rewriter, loc, output, newOutputType);
1076 
1077     // Rank-reduce strides and dilations too.
1078     // TODO: dropDim 1-liner helper.
1079     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1080     strides.erase(strides.begin() + (removeH ? 0 : 1));
1081     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1082 
1083     auto dilations =
1084         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1085     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1086     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1087 
1088     auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1089         loc, newOutputType, ValueRange{newInput, newKernel},
1090         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1091 
1092     // Insert back.
1093     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1094         rewriter, loc, conv1DOp.getResult(0), output);
1095     rewriter.replaceOp(convOp, inserted);
1096 
1097     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1098     return success();
1099   };
1100 
1101 private:
1102   /// LinalgTransformMarker handles special attribute manipulations.
1103   LinalgTransformationFilter filter;
1104 };
1105 
1106 } // namespace
1107 
1108 void linalg::populateDecomposeConvolutionPatterns(
1109     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
1110     PatternBenefit benefit) {
1111   patterns.add<DownscaleSizeOneWindowed2DConvolution,
1112                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1113                                                   benefit);
1114 }
1115